MLIR  21.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Verifier.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Pass/PassRegistry.h"
33 #include "mlir/Transforms/CSE.h"
37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include "llvm/Support/InterleavedRange.h"
45 #include <optional>
46 
47 #define DEBUG_TYPE "transform-dialect"
48 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
49 
50 #define DEBUG_TYPE_MATCHER "transform-matcher"
51 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
52 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
53 
54 using namespace mlir;
55 
// Custom assembly-format hooks for the sequence-op operands and the
// foreach_match symbol lists. They are declared ahead of the generated op
// definitions included below, presumably because that generated code calls
// them — confirm.
56 static ParseResult parseSequenceOpOperands(
57  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
58  Type &rootType,
59  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
60  SmallVectorImpl<Type> &extraBindingTypes);
61 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
62  Value root, Type rootType,
63  ValueRange extraBindings,
64  TypeRange extraBindingTypes);
65 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
66  ArrayAttr matchers, ArrayAttr actions);
67 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
68  ArrayAttr &matchers,
69  ArrayAttr &actions);
70 
71 /// Helper function to check if the given transform op is contained in (or
72 /// equal to) the given payload target op. In that case, an error is returned.
73 /// Transforming transform IR that is currently executing is generally unsafe.
/// Implementation: starts at `transform` and walks up through its parent ops;
/// if any of them is `payload`, a definite failure is emitted with a note at
/// the payload's location.
/// NOTE(review): this listing is missing original lines 74 (return type), 80
/// (declaration of `diag`) and 88 (presumably the success return); restore
/// them from upstream before compiling.
75 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
76  Operation *payload) {
77  Operation *transformAncestor = transform.getOperation();
78  while (transformAncestor) {
79  if (transformAncestor == payload) {
81  transform.emitDefiniteFailure()
82  << "cannot apply transform to itself (or one of its ancestors)";
83  diag.attachNote(payload->getLoc()) << "target payload op";
84  return diag;
85  }
// Step one level up; the loop ends after the top-most op (null parent).
86  transformAncestor = transformAncestor->getParentOp();
87  }
89 }
90 
91 #define GET_OP_CLASSES
92 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
93 
94 //===----------------------------------------------------------------------===//
95 // AlternativesOp
96 //===----------------------------------------------------------------------===//
97 
// Operands forwarded when entering the successor designated by `point`: all
// of the op's operands when `point` designates a region (is not the parent
// op) and the op has exactly one operand; otherwise an empty range anchored at
// the end of the operand list.
// NOTE(review): original line 98 (the return type, presumably `OperandRange`)
// is missing from this listing.
99 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
100  if (!point.isParent() && getOperation()->getNumOperands() == 1)
101  return getOperation()->getOperands();
102  return OperandRange(getOperation()->operand_end(),
103  getOperation()->operand_end());
104 }
105 
// Successors of a branch point. From the parent op, every alternative region
// is a successor. From an alternative region, every later alternative is a
// successor, and so are the op's own results.
// When the op has operands, each alternative's inputs are its block
// arguments; the else-branch of that ternary sits on a missing line.
// NOTE(review): original line 114 (the rest of the `emplace_back` call) is
// missing from this listing.
106 void transform::AlternativesOp::getSuccessorRegions(
107  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
108  for (Region &alternative : llvm::drop_begin(
109  getAlternatives(),
110  point.isParent() ? 0
111  : point.getRegionOrNull()->getRegionNumber() + 1)) {
112  regions.emplace_back(&alternative, !getOperands().empty()
113  ? alternative.getArguments()
115  }
116  if (!point.isParent())
117  regions.emplace_back(getOperation()->getResults());
118 }
119 
120 void transform::AlternativesOp::getRegionInvocationBounds(
121  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
122  (void)operands;
123  // The region corresponding to the first alternative is always executed, the
124  // remaining may or may not be executed.
125  bounds.reserve(getNumRegions());
126  bounds.emplace_back(1, 1);
127  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
128 }
129 
// Sets every result of `block`'s parent op to an empty payload list in
// `results`.
// NOTE(review): original line 130 (the function name and the leading
// parameters, including `block`) is missing from this listing.
131  transform::TransformResults &results) {
132  for (const auto &res : block->getParentOp()->getOpResults())
133  results.set(res, {});
134 }
135 
// Tries the alternative regions in order.
//  - Each alternative runs on fresh clones of the scope ops. The scope ops are
//    the payload of the optional scope handle, or else the top-level op.
//  - For the first alternative whose transforms all succeed, the clones
//    replace the originals and its terminator operands are forwarded as the
//    op's results.
//  - If every alternative fails silenceably, a silenceable error is returned.
// NOTE(review): this listing is missing original lines 136 and 138 (return
// type and a parameter, presumably `results`), 174 and 188 (the bodies of the
// `if` statements on lines 173 and 187), 178 (the declaration of `result`)
// and 207; restore them from upstream before compiling.
137 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
139  transform::TransformState &state) {
140  SmallVector<Operation *> originals;
141  if (Value scopeHandle = getScope())
142  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
143  else
144  originals.push_back(state.getTopLevel());
145 
// Reject scopes that contain this op (the transform would rewrite itself) and
// scopes that are not isolated from above.
146  for (Operation *original : originals) {
147  if (original->isAncestor(getOperation())) {
148  auto diag = emitDefiniteFailure()
149  << "scope must not contain the transforms being applied";
150  diag.attachNote(original->getLoc()) << "scope";
151  return diag;
152  }
153  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
154  auto diag = emitDefiniteFailure()
155  << "only isolated-from-above ops can be alternative scopes";
156  diag.attachNote(original->getLoc()) << "scope";
157  return diag;
158  }
159  }
160 
161  for (Region &reg : getAlternatives()) {
162  // Clone the scope operations and make the transforms in this alternative
163  // region apply to them by virtue of mapping the block argument (the only
164  // visible handle) to the cloned scope operations. This effectively prevents
165  // the transformation from accessing any IR outside the scope.
166  auto scope = state.make_region_scope(reg);
167  auto clones = llvm::to_vector(
168  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
// The clones are erased when this iteration ends unless the deletion is
// released below (successful alternative).
169  auto deleteClones = llvm::make_scope_exit([&] {
170  for (Operation *clone : clones)
171  clone->erase();
172  });
173  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
175 
// This local is named `failed` and shadows the mlir::failed() helper, which is
// why the call further down is spelled `::mlir::failed`.
176  bool failed = false;
177  for (Operation &transform : reg.front().without_terminator()) {
179  state.applyTransform(cast<TransformOpInterface>(transform));
180  if (result.isSilenceableFailure()) {
181  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
182  << "\n");
183  failed = true;
184  break;
185  }
186 
187  if (::mlir::failed(result.silence()))
189  }
190 
191  // If all operations in the given alternative succeeded, no need to consider
192  // the rest. Replace the original scoping operation with the clone on which
193  // the transformations were performed.
194  if (!failed) {
195  // We will be using the clones, so cancel their scheduled deletion.
196  deleteClones.release();
197  TrackingListener listener(state, *this);
// NOTE(review): this local `rewriter` shadows the `rewriter` parameter of
// apply(); the replacements below go through this IRRewriter and its
// TrackingListener instead.
198  IRRewriter rewriter(getContext(), &listener);
199  for (const auto &kvp : llvm::zip(originals, clones)) {
200  Operation *original = std::get<0>(kvp);
201  Operation *clone = std::get<1>(kvp);
202  original->getBlock()->getOperations().insert(original->getIterator(),
203  clone);
204  rewriter.replaceOp(original, clone->getResults());
205  }
206  detail::forwardTerminatorOperands(&reg.front(), state, results);
208  }
209  }
210  return emitSilenceableError() << "all alternatives failed";
211 }
212 
213 void transform::AlternativesOp::getEffects(
214  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215  consumesHandle(getOperation()->getOpOperands(), effects);
216  producesHandle(getOperation()->getOpResults(), effects);
217  for (Region *region : getRegions()) {
218  if (!region->empty())
219  producesHandle(region->front().getArguments(), effects);
220  }
221  modifiesPayload(effects);
222 }
223 
224 LogicalResult transform::AlternativesOp::verify() {
225  for (Region &alternative : getAlternatives()) {
226  Block &block = alternative.front();
227  Operation *terminator = block.getTerminator();
228  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
229  InFlightDiagnostic diag = emitOpError()
230  << "expects terminator operands to have the "
231  "same type as results of the operation";
232  diag.attachNote(terminator->getLoc()) << "terminator";
233  return diag;
234  }
235  }
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AnnotateOp
242 //===----------------------------------------------------------------------===//
243 
// Sets the attribute named getName() on every payload op of the target
// handle.
//  - If the param handle's payload holds exactly one attribute, that
//    attribute is used for all targets (`attr = params[0]`).
//  - If it holds any other number, that number must equal the number of
//    targets, and the attributes are applied pairwise.
//  - Without a param handle, `attr` keeps the value given on the missing
//    line 251.
// NOTE(review): original lines 244 and 246 (return type and a parameter), 251
// (declaration of `attr`), 262 (presumably a return after the pairwise loop)
// and 268 (final return) are missing from this listing.
245 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
247  transform::TransformState &state) {
248  SmallVector<Operation *> targets =
249  llvm::to_vector(state.getPayloadOps(getTarget()));
250 
252  if (auto paramH = getParam()) {
253  ArrayRef<Attribute> params = state.getParams(paramH);
254  if (params.size() != 1) {
255  if (targets.size() != params.size()) {
256  return emitSilenceableError()
257  << "parameter and target have different payload lengths ("
258  << params.size() << " vs " << targets.size() << ")";
259  }
// NOTE(review): the structured binding `attr` shadows the outer `attr`
// declared on the missing line 251.
260  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
261  target->setAttr(getName(), attr);
263  }
264  attr = params[0];
265  }
266  for (auto *target : targets)
267  target->setAttr(getName(), attr);
269 }
270 
271 void transform::AnnotateOp::getEffects(
272  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
273  onlyReadsHandle(getTargetMutable(), effects);
274  onlyReadsHandle(getParamMutable(), effects);
275  modifiesPayload(effects);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ApplyCommonSubexpressionEliminationOp
280 //===----------------------------------------------------------------------===//
281 
// Runs common-subexpression elimination on `target` through the transform
// rewriter (so tracked handles are updated), after a payload check that must
// succeed first.
// NOTE(review): original lines 282 (return type), 289 (the call producing
// `payloadCheck`, presumably ensurePayloadIsSeparateFromTransform) and 295
// (final return) are missing from this listing.
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
284  transform::TransformRewriter &rewriter, Operation *target,
285  ApplyToEachResultList &results, transform::TransformState &state) {
286  // Make sure that this transform is not applied to itself. Modifying the
287  // transform IR while it is being interpreted is generally dangerous.
288  DiagnosedSilenceableFailure payloadCheck =
290  if (!payloadCheck.succeeded())
291  return payloadCheck;
292 
293  DominanceInfo domInfo;
294  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
296 }
297 
// Declares that the target handle is only read.
// NOTE(review): original line 301 is missing from this listing (presumably a
// payload-modification effect; confirm against upstream).
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
300  transform::onlyReadsHandle(getTargetMutable(), effects);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ApplyDeadCodeEliminationOp
306 //===----------------------------------------------------------------------===//
307 
// Erases trivially dead ops nested under `target` (never `target` itself),
// through the transform rewriter.
//  1. A post-order walk erases the ops that are dead up front.
//  2. The defining ops of their operands that lie strictly inside `target`
//     are queued.
//  3. Queued ops are erased in turn once they have become dead.
// NOTE(review): original lines 314 (presumably the payload-separation check
// call) and 360 (final return) are missing from this listing.
308 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
309  transform::TransformRewriter &rewriter, Operation *target,
310  ApplyToEachResultList &results, transform::TransformState &state) {
311  // Make sure that this transform is not applied to itself. Modifying the
312  // transform IR while it is being interpreted is generally dangerous.
313  DiagnosedSilenceableFailure payloadCheck =
315  if (!payloadCheck.succeeded())
316  return payloadCheck;
317 
318  // Maintain a worklist of potentially dead ops.
319  SetVector<Operation *> worklist;
320 
321  // Helper function that adds all defining ops of used values (operands and
322  // operands of nested ops).
323  auto addDefiningOpsToWorklist = [&](Operation *op) {
324  op->walk([&](Operation *op) {
325  for (Value v : op->getOperands())
326  if (Operation *defOp = v.getDefiningOp())
327  if (target->isProperAncestor(defOp))
328  worklist.insert(defOp);
329  });
330  };
331 
332  // Helper function that erases an op.
333  auto eraseOp = [&](Operation *op) {
334  // Remove op and nested ops from the worklist.
// NOTE(review): llvm::find scans the worklist linearly for every nested op.
335  op->walk([&](Operation *op) {
336  const auto *it = llvm::find(worklist, op);
337  if (it != worklist.end())
338  worklist.erase(it);
339  });
340  rewriter.eraseOp(op);
341  };
342 
343  // Initial walk over the IR.
344  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
345  if (op != target && isOpTriviallyDead(op)) {
346  addDefiningOpsToWorklist(op);
347  eraseOp(op);
348  }
349  });
350 
351  // Erase all ops that have become dead.
352  while (!worklist.empty()) {
353  Operation *op = worklist.pop_back_val();
354  if (!isOpTriviallyDead(op))
355  continue;
356  addDefiningOpsToWorklist(op);
357  eraseOp(op);
358  }
359 
361 }
362 
// Declares that the target handle is only read.
// NOTE(review): original line 366 is missing from this listing (presumably a
// payload-modification effect; confirm against upstream).
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
365  transform::onlyReadsHandle(getTargetMutable(), effects);
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // ApplyPatternsOp
371 //===----------------------------------------------------------------------===//
372 
// Applies the patterns contributed by the PatternDescriptorOpInterface ops in
// this op's region to `target`, using the greedy rewrite driver.
//  - Isolated-from-above targets go through applyPatternsGreedily.
//  - Otherwise, the ops nested in `target` (excluding `target`) are gathered
//    and passed to applyOpPatternsGreedily.
//  - When getApplyCse() is set, CSE runs after each round, and rounds repeat
//    while CSE still changes the IR.
//  - Reaching kNumMaxIterations rounds is a definite failure.
// NOTE(review): this listing is missing original lines 381 (the payload check
// call), 387, 396 and 428 (the declarations of `patterns`, `config` and
// `ops`), 402 and 405 (the arms chosen when an attribute equals
// static_cast<uint64_t>(-1)), and 453 (final return).
373 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
374  transform::TransformRewriter &rewriter, Operation *target,
375  ApplyToEachResultList &results, transform::TransformState &state) {
376  // Make sure that this transform is not applied to itself. Modifying the
377  // transform IR while it is being interpreted is generally dangerous. Even
378  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
379  // performs many additional simplifications such as dead code elimination.
380  DiagnosedSilenceableFailure payloadCheck =
382  if (!payloadCheck.succeeded())
383  return payloadCheck;
384 
385  // Gather all specified patterns.
386  MLIRContext *ctx = target->getContext();
388  if (!getRegion().empty()) {
389  for (Operation &op : getRegion().front()) {
390  cast<transform::PatternDescriptorOpInterface>(&op)
391  .populatePatternsWithState(patterns, state);
392  }
393  }
394 
395  // Configure the GreedyPatternRewriteDriver.
397  config.listener =
398  static_cast<RewriterBase::Listener *>(rewriter.getListener());
399  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
400 
401  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
403  : getMaxIterations();
404  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
406  : getMaxNumRewrites();
407 
408  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
409  // was requested, apply the greedy pattern rewrite only once. (The greedy
410  // pattern rewrite driver already iterates to a fixpoint internally.)
411  bool cseChanged = false;
412  // One or two iterations should be sufficient. Stop iterating after a certain
413  // threshold to make debugging easier.
414  static const int64_t kNumMaxIterations = 50;
415  int64_t iteration = 0;
416  do {
417  LogicalResult result = failure();
418  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
419  // Op is isolated from above. Apply patterns and also perform region
420  // simplification.
421  result = applyPatternsGreedily(target, frozenPatterns, config);
422  } else {
423  // Manually gather list of ops because the other
424  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
425  // from above. This way, patterns can be applied to ops that are not
426  // isolated from above. Regions are not being simplified. Furthermore,
427  // only a single greedy rewrite iteration is performed.
429  target->walk([&](Operation *nestedOp) {
430  if (target != nestedOp)
431  ops.push_back(nestedOp);
432  });
433  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
434  }
435 
436  // A failure typically indicates that the pattern application did not
437  // converge.
438  if (failed(result)) {
439  return emitSilenceableFailure(target)
440  << "greedy pattern application failed";
441  }
442 
443  if (getApplyCse()) {
444  DominanceInfo domInfo;
445  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
446  &cseChanged);
447  }
// `++iteration` is only evaluated when CSE changed something, so `iteration`
// reaches kNumMaxIterations only if every round changed the IR.
448  } while (cseChanged && ++iteration < kNumMaxIterations);
449 
450  if (iteration == kNumMaxIterations)
451  return emitDefiniteFailure() << "fixpoint iteration did not converge";
452 
454 }
455 
456 LogicalResult transform::ApplyPatternsOp::verify() {
457  if (!getRegion().empty()) {
458  for (Operation &op : getRegion().front()) {
459  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
460  InFlightDiagnostic diag = emitOpError()
461  << "expected children ops to implement "
462  "PatternDescriptorOpInterface";
463  diag.attachNote(op.getLoc()) << "op without interface";
464  return diag;
465  }
466  }
467  }
468  return success();
469 }
470 
// Declares that the target handle is only read.
// NOTE(review): original line 474 is missing from this listing (presumably a
// payload-modification effect; confirm against upstream).
471 void transform::ApplyPatternsOp::getEffects(
472  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
473  transform::onlyReadsHandle(getTargetMutable(), effects);
475 }
476 
477 void transform::ApplyPatternsOp::build(
478  OpBuilder &builder, OperationState &result, Value target,
479  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
480  result.addOperands(target);
481 
482  OpBuilder::InsertionGuard g(builder);
483  Region *region = result.addRegion();
484  builder.createBlock(region);
485  if (bodyBuilder)
486  bodyBuilder(builder, result.location);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // ApplyCanonicalizationPatternsOp
491 //===----------------------------------------------------------------------===//
492 
// Fills `patterns` with the canonicalization patterns of every loaded dialect,
// then with the per-op canonicalization patterns of each `op` visited by the
// loop whose header is on a missing line.
// NOTE(review): original lines 494 (the parameter list, including
// `patterns`) and 498 (the loop header introducing `op`) are missing.
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
495  MLIRContext *ctx = patterns.getContext();
496  for (Dialect *dialect : ctx->getLoadedDialects())
497  dialect->getCanonicalizationPatterns(patterns);
499  op.getCanonicalizationPatterns(patterns, ctx);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // ApplyConversionPatternsOp
504 //===----------------------------------------------------------------------===//
505 
// Runs a dialect conversion on every payload op of the target handle.
//  - Conversion mode: partial when getPartialConversion() is set, otherwise
//    full.
//  - ConversionTarget: built from the legal/illegal op and dialect lists,
//    plus rules contributed by each pattern descriptor.
//  - Type converter: each descriptor uses its own converter if it supplies
//    one, otherwise the default built from getDefaultTypeConverter().
//  - Handle preservation: when getPreserveHandles() is set, a tracking
//    listener is attached. Its error is reported on its own, or as a note
//    when the conversion itself also failed.
// NOTE(review): original lines 507 and 508 (the parameters), 537 (declaration
// of `patterns`), 592 (the payload check call), 606 (declaration of `diag`)
// and 630 (final return) are missing from this listing.
506 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
509  MLIRContext *ctx = getContext();
510 
511  // Instantiate the default type converter if a type converter builder is
512  // specified.
513  std::unique_ptr<TypeConverter> defaultTypeConverter;
514  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515  getDefaultTypeConverter();
516  if (typeConverterBuilder)
517  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
518 
519  // Configure conversion target.
520  ConversionTarget conversionTarget(*getContext());
521  if (getLegalOps())
522  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523  conversionTarget.addLegalOp(
524  OperationName(cast<StringAttr>(attr).getValue(), ctx));
525  if (getIllegalOps())
526  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527  conversionTarget.addIllegalOp(
528  OperationName(cast<StringAttr>(attr).getValue(), ctx));
529  if (getLegalDialects())
530  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532  if (getIllegalDialects())
533  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
535 
536  // Gather all specified patterns.
538  // Need to keep the converters alive until after pattern application because
539  // the patterns take a reference to an object that would otherwise get out of
540  // scope.
541  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
542  if (!getPatterns().empty()) {
543  for (Operation &op : getPatterns().front()) {
544  auto descriptor =
545  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
546 
547  // Check if this pattern set specifies a type converter.
548  std::unique_ptr<TypeConverter> typeConverter =
549  descriptor.getTypeConverter();
550  TypeConverter *converter = nullptr;
551  if (typeConverter) {
552  keepAliveConverters.emplace_back(std::move(typeConverter));
553  converter = keepAliveConverters.back().get();
554  } else {
555  // No type converter specified: Use the default type converter.
556  if (!defaultTypeConverter) {
557  auto diag = emitDefiniteFailure()
558  << "pattern descriptor does not specify type "
559  "converter and apply_conversion_patterns op has "
560  "no default type converter";
561  diag.attachNote(op.getLoc()) << "pattern descriptor op";
562  return diag;
563  }
564  converter = defaultTypeConverter.get();
565  }
566 
567  // Add descriptor-specific updates to the conversion target, which may
568  // depend on the final type converter. In structural converters, the
569  // legality of types dictates the dynamic legality of an operation.
570  descriptor.populateConversionTargetRules(*converter, conversionTarget);
571 
572  descriptor.populatePatterns(*converter, patterns);
573  }
574  }
575 
576  // Attach a tracking listener if handles should be preserved. We configure the
577  // listener to allow op replacements with different names, as conversion
578  // patterns typically replace ops with replacement ops that have a different
579  // name.
580  TrackingListenerConfig trackingConfig;
581  trackingConfig.requireMatchingReplacementOpName = false;
582  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
583  ConversionConfig conversionConfig;
584  if (getPreserveHandles())
585  conversionConfig.listener = &trackingListener;
586 
587  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
588  for (Operation *target : state.getPayloadOps(getTarget())) {
589  // Make sure that this transform is not applied to itself. Modifying the
590  // transform IR while it is being interpreted is generally dangerous.
591  DiagnosedSilenceableFailure payloadCheck =
593  if (!payloadCheck.succeeded())
594  return payloadCheck;
595 
596  LogicalResult status = failure();
597  if (getPartialConversion()) {
598  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
599  conversionConfig);
600  } else {
601  status = applyFullConversion(target, conversionTarget, frozenPatterns,
602  conversionConfig);
603  }
604 
605  // Check dialect conversion state.
607  if (failed(status)) {
608  diag = emitSilenceableError() << "dialect conversion failed";
609  diag.attachNote(target->getLoc()) << "target op";
610  }
611 
612  // Check tracking listener error state.
// The listener is consulted even when it was never attached (handles not
// preserved); in that case it has had no chance to record an error.
613  DiagnosedSilenceableFailure trackingFailure =
614  trackingListener.checkAndResetError();
615  if (!trackingFailure.succeeded()) {
616  if (diag.succeeded()) {
617  // Tracking failure is the only failure.
618  return trackingFailure;
619  } else {
620  diag.attachNote() << "tracking listener also failed: "
621  << trackingFailure.getMessage();
622  (void)trackingFailure.silence();
623  }
624  }
625 
626  if (!diag.succeeded())
627  return diag;
628  }
629 
631 }
632 
// Verifier body for the conversion-patterns op (its signature is on a missing
// line). Checks, in order:
//  - the op has one or two regions;
//  - every op in the patterns region implements
//    ConversionPatternDescriptorOpInterface;
//  - if a second region exists, it holds exactly one op, and that op
//    implements TypeConverterBuilderOpInterface;
//  - every pattern descriptor accepts that type converter.
// NOTE(review): original lines 633 (the signature) and 639 (declaration of
// `diag`) are missing from this listing.
634  if (getNumRegions() != 1 && getNumRegions() != 2)
635  return emitOpError() << "expected 1 or 2 regions";
636  if (!getPatterns().empty()) {
637  for (Operation &op : getPatterns().front()) {
638  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
640  emitOpError() << "expected pattern children ops to implement "
641  "ConversionPatternDescriptorOpInterface";
642  diag.attachNote(op.getLoc()) << "op without interface";
643  return diag;
644  }
645  }
646  }
647  if (getNumRegions() == 2) {
648  Region &typeConverterRegion = getRegion(1);
// NOTE(review): `front()` assumes the second region contains a block; confirm
// the region constraint guarantees that before relying on it.
649  if (!llvm::hasSingleElement(typeConverterRegion.front()))
650  return emitOpError()
651  << "expected exactly one op in default type converter region";
652  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
653  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
654  maybeTypeConverter);
655  if (!typeConverterOp) {
656  InFlightDiagnostic diag = emitOpError()
657  << "expected default converter child op to "
658  "implement TypeConverterBuilderOpInterface";
659  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
660  return diag;
661  }
662  // Check default type converter type.
663  if (!getPatterns().empty()) {
664  for (Operation &op : getPatterns().front()) {
665  auto descriptor =
666  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
667  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
668  return failure();
669  }
670  }
671  }
672  return success();
673 }
674 
// The target handle is consumed, unless handles are preserved, in which case
// it is only read.
// NOTE(review): original line 682 is missing from this listing (presumably a
// payload-modification effect; confirm against upstream).
675 void transform::ApplyConversionPatternsOp::getEffects(
676  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
677  if (!getPreserveHandles()) {
678  transform::consumesHandle(getTargetMutable(), effects);
679  } else {
680  transform::onlyReadsHandle(getTargetMutable(), effects);
681  }
683 }
684 
685 void transform::ApplyConversionPatternsOp::build(
686  OpBuilder &builder, OperationState &result, Value target,
687  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
688  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
689  result.addOperands(target);
690 
691  {
692  OpBuilder::InsertionGuard g(builder);
693  Region *region1 = result.addRegion();
694  builder.createBlock(region1);
695  if (patternsBodyBuilder)
696  patternsBodyBuilder(builder, result.location);
697  }
698  {
699  OpBuilder::InsertionGuard g(builder);
700  Region *region2 = result.addRegion();
701  builder.createBlock(region2);
702  if (typeConverterBodyBuilder)
703  typeConverterBodyBuilder(builder, result.location);
704  }
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // ApplyToLLVMConversionPatternsOp
709 //===----------------------------------------------------------------------===//
710 
711 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
712  TypeConverter &typeConverter, RewritePatternSet &patterns) {
713  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
714  assert(dialect && "expected that dialect is loaded");
715  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
716  // ConversionTarget is currently ignored because the enclosing
717  // apply_conversion_patterns op sets up its own ConversionTarget.
718  ConversionTarget target(*getContext());
719  iface->populateConvertToLLVMConversionPatterns(
720  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
721 }
722 
723 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
724  transform::TypeConverterBuilderOpInterface builder) {
725  if (builder.getTypeConverterType() != "LLVMTypeConverter")
726  return emitOpError("expected LLVMTypeConverter");
727  return success();
728 }
729 
// Verifier body (its signature is on a missing line). Checks that the dialect
// named getDialectName() is loaded and implements
// ConvertToLLVMPatternInterface.
// NOTE(review): original line 730 (the signature) is missing from this
// listing.
731  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
732  if (!dialect)
733  return emitOpError("unknown dialect or dialect not loaded: ")
734  << getDialectName();
735  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
736  if (!iface)
737  return emitOpError(
738  "dialect does not implement ConvertToLLVMPatternInterface or "
739  "extension was not loaded: ")
740  << getDialectName();
741  return success();
742 }
743 
744 //===----------------------------------------------------------------------===//
745 // ApplyLoopInvariantCodeMotionOp
746 //===----------------------------------------------------------------------===//
747 
// Hoists loop-invariant code out of the loop `target` with
// moveLoopInvariantCode. The transform rewriter is not used (see below).
// NOTE(review): original lines 748 (return type), 751 (the parameter between
// `target` and `state`) and 756 (presumably the final return) are missing.
749 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
750  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
752  transform::TransformState &state) {
753  // Currently, LICM does not remove operations, so we don't need tracking.
754  // If this ever changes, add a LICM entry point that takes a rewriter.
755  moveLoopInvariantCode(target);
757 }
758 
// Declares that the target handle is only read.
// NOTE(review): original line 762 is missing from this listing (presumably a
// payload-modification effect; confirm against upstream).
759 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
760  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
761  transform::onlyReadsHandle(getTargetMutable(), effects);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // ApplyRegisteredPassOp
767 //===----------------------------------------------------------------------===//
768 
// Runs the registered pass or pass pipeline named getPassName() on `target`
// with a fresh PassManager, then returns `target` as the result.
//  - Lookup: pipelines are searched first, then passes.
//  - Options: getOptions() is used when adding the pass or pipeline.
//  - An unknown name is a definite failure.
//  - A failure while adding to the pipeline is a definite failure; the error
//    message is emitted first.
//  - A failing run is a silenceable error with a note at `target`.
// NOTE(review): original lines 777 (presumably the payload check call) and
// 806 (final return) are missing from this listing.
769 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
770  transform::TransformRewriter &rewriter, Operation *target,
771  ApplyToEachResultList &results, transform::TransformState &state) {
772  // Make sure that this transform is not applied to itself. Modifying the
773  // transform IR while it is being interpreted is generally dangerous. Even
774  // more so when applying passes because they may perform a wide range of IR
775  // modifications.
776  DiagnosedSilenceableFailure payloadCheck =
778  if (!payloadCheck.succeeded())
779  return payloadCheck;
780 
781  // Get pass or pass pipeline from registry.
782  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
783  if (!info)
784  info = PassInfo::lookup(getPassName());
785  if (!info)
786  return emitDefiniteFailure()
787  << "unknown pass or pass pipeline: " << getPassName();
788 
789  // Create pass manager and run the pass or pass pipeline.
790  PassManager pm(getContext());
791  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
792  emitError(msg);
793  return failure();
794  }))) {
795  return emitDefiniteFailure()
796  << "failed to add pass or pass pipeline to pipeline: "
797  << getPassName();
798  }
799  if (failed(pm.run(target))) {
800  auto diag = emitSilenceableError() << "pass pipeline failed";
801  diag.attachNote(target->getLoc()) << "target op";
802  return diag;
803  }
804 
805  results.push_back(target);
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // CastOp
811 //===----------------------------------------------------------------------===//
812 
// Forwards each target op unchanged to the result handle; only the handle
// type changes.
// NOTE(review): original lines 813 (return type) and 818 (presumably the
// final return) are missing from this listing.
814 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
815  Operation *target, ApplyToEachResultList &results,
816  transform::TransformState &state) {
817  results.push_back(target);
819 }
820 
821 void transform::CastOp::getEffects(
822  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
823  onlyReadsPayload(effects);
824  onlyReadsHandle(getInputMutable(), effects);
825  producesHandle(getOperation()->getOpResults(), effects);
826 }
827 
828 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
829  assert(inputs.size() == 1 && "expected one input");
830  assert(outputs.size() == 1 && "expected one output");
831  return llvm::all_of(
832  std::initializer_list<Type>{inputs.front(), outputs.front()},
833  llvm::IsaPred<transform::TransformHandleTypeInterface>);
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // CollectMatchingOp
838 //===----------------------------------------------------------------------===//
839 
840 /// Applies matcher operations from the given `block` using
841 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
842 /// accordingly. If any of the matchers produces a silenceable failure, discards
843 /// it (printing the content to the debug output stream) and returns failure. If
844 /// any of the matchers produces a definite failure, reports it and returns
845 /// failure. If all matchers in the block succeed, populates `mappings` with the
846 /// payload entities associated with the block terminator operands. Note that
847 /// `mappings` will be cleared before that.
/// NOTE(review): in the visible body, the first non-successful matcher result
/// is returned to the caller unchanged (`return diag;`). No discarding or
/// debug printing happens here, so that part of the description above may
/// refer to the caller — confirm. A non-terminator op that does not implement
/// MatchOpInterface yields a definite failure at that op's location.
/// NOTE(review): original lines 848, 849 and 851 (signature and parameters),
/// 857 (the body of the `if` on line 855), 865 (declaration of `diag`) and
/// 880 (final return) are missing from this listing.
850  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
852  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
853  assert(block.getParent() && "cannot match using a detached block");
854  auto matchScope = state.make_region_scope(*block.getParent());
855  if (failed(
856  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
858 
859  for (Operation &match : block.without_terminator()) {
860  if (!isa<transform::MatchOpInterface>(match)) {
861  return emitDefiniteFailure(match.getLoc())
862  << "expected operations in the match part to "
863  "implement MatchOpInterface";
864  }
866  state.applyTransform(cast<transform::TransformOpInterface>(match));
867  if (diag.succeeded())
868  continue;
869 
870  return diag;
871  }
872 
873  // Remember the values mapped to the terminator operands so we can
874  // forward them to the action.
875  ValueRange yieldedValues = block.getTerminator()->getOperands();
876  // Our contract with the caller is that the mappings will contain only the
877  // newly mapped values, clear the rest.
878  mappings.clear();
879  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
881 }
882 
883 /// Returns `true` if both types implement one of the interfaces provided as
884 /// template parameters.
885 template <typename... Tys>
886 static bool implementSameInterface(Type t1, Type t2) {
887  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
888 }
889 
890 /// Returns `true` if both types implement the same one of the transform
891 /// dialect type interfaces: operation handle (TransformHandleTypeInterface),
/// parameter (TransformParamTypeInterface) or value handle
/// (TransformValueHandleTypeInterface). The two types need not be identical,
/// they only need to be of the same kind.
893  return implementSameInterface<transform::TransformHandleTypeInterface,
894  transform::TransformParamTypeInterface,
895  transform::TransformValueHandleTypeInterface>(
896  t1, t2);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // CollectMatchingOp
901 //===----------------------------------------------------------------------===//
902 
904 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
906  transform::TransformState &state) {
907  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
908  getOperation(), getMatcher());
909  if (matcher.isExternal()) {
910  return emitDefiniteFailure()
911  << "unresolved external symbol " << getMatcher();
912  }
913 
915  rawResults.resize(getOperation()->getNumResults());
916  std::optional<DiagnosedSilenceableFailure> maybeFailure;
917  for (Operation *root : state.getPayloadOps(getRoot())) {
918  WalkResult walkResult = root->walk([&](Operation *op) {
919  DEBUG_MATCHER({
920  DBGS_MATCHER() << "matching ";
921  op->print(llvm::dbgs(),
922  OpPrintingFlags().assumeVerified().skipRegions());
923  llvm::dbgs() << " @" << op << "\n";
924  });
925 
926  // Try matching.
928  SmallVector<transform::MappedValue> inputMapping({op});
930  matcher.getFunctionBody().front(),
931  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
932  mappings);
933  if (diag.isDefiniteFailure())
934  return WalkResult::interrupt();
935  if (diag.isSilenceableFailure()) {
936  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
937  << " failed: " << diag.getMessage());
938  return WalkResult::advance();
939  }
940 
941  // If succeeded, collect results.
942  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
943  if (mapping.size() != 1) {
944  maybeFailure.emplace(emitSilenceableError()
945  << "result #" << i << ", associated with "
946  << mapping.size()
947  << " payload objects, expected 1");
948  return WalkResult::interrupt();
949  }
950  rawResults[i].push_back(mapping[0]);
951  }
952  return WalkResult::advance();
953  });
954  if (walkResult.wasInterrupted())
955  return std::move(*maybeFailure);
956  assert(!maybeFailure && "failure set but the walk was not interrupted");
957 
958  for (auto &&[opResult, rawResult] :
959  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
960  results.setMappedValues(opResult, rawResult);
961  }
962  }
964 }
965 
966 void transform::CollectMatchingOp::getEffects(
967  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
968  onlyReadsHandle(getRootMutable(), effects);
969  producesHandle(getOperation()->getOpResults(), effects);
970  onlyReadsPayload(effects);
971 }
972 
973 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
974  SymbolTableCollection &symbolTable) {
975  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
976  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
977  if (!matcherSymbol ||
978  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
979  return emitError() << "unresolved matcher symbol " << getMatcher();
980 
981  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
982  if (argumentTypes.size() != 1 ||
983  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
984  return emitError()
985  << "expected the matcher to take one operation handle argument";
986  }
987  if (!matcherSymbol.getArgAttr(
988  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
989  return emitError() << "expected the matcher argument to be marked readonly";
990  }
991 
992  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
993  if (resultTypes.size() != getOperation()->getNumResults()) {
994  return emitError()
995  << "expected the matcher to yield as many values as op has results ("
996  << getOperation()->getNumResults() << "), got "
997  << resultTypes.size();
998  }
999 
1000  for (auto &&[i, matcherType, resultType] :
1001  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1002  if (implementSameTransformInterface(matcherType, resultType))
1003  continue;
1004 
1005  return emitError()
1006  << "mismatching type interfaces for matcher result and op result #"
1007  << i;
1008  }
1009 
1010  return success();
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // ForeachMatchOp
1015 //===----------------------------------------------------------------------===//
1016 
1017 // This is fine because nothing is actually consumed by this op.
1018 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1019 
/// Applies, to every payload op nested in the ops associated with the root
/// handle, the first match/action pair whose matcher succeeds. Matcher and
/// action symbols are resolved up front; external symbols are a definite
/// failure. During the walk, the root op itself is skipped unless
/// `getRestrictRoot()` is set. For each visited op, pairs are tried in order:
/// the matcher receives the op followed by the forwarded inputs; on success,
/// the action is applied to the matcher's yielded payload, the action's yielded
/// payload is appended to the forwarded outputs, and no further pair is tried
/// for that op. Silenceable failures of action transforms are silenced and
/// aggregated into a single "actions failed" silenceable error returned at the
/// end; definite failures abort the walk.
1021 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1022  transform::TransformResults &results,
1023  transform::TransformState &state) {
1025  matchActionPairs;
1026  matchActionPairs.reserve(getMatchers().size());
1027  SymbolTableCollection symbolTable;
1028  for (auto &&[matcher, action] :
1029  llvm::zip_equal(getMatchers(), getActions())) {
1030  auto matcherSymbol =
1031  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1032  getOperation(), cast<SymbolRefAttr>(matcher));
1033  auto actionSymbol =
1034  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1035  getOperation(), cast<SymbolRefAttr>(action));
1036  assert(matcherSymbol && actionSymbol &&
1037  "unresolved symbols not caught by the verifier");
1038 
1039  if (matcherSymbol.isExternal())
1040  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1041  if (actionSymbol.isExternal())
1042  return emitDefiniteFailure() << "unresolved external symbol " << action;
1043 
1044  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1045  }
1046 
1047  DiagnosedSilenceableFailure overallDiag =
1049 
1050  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1051  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1052  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1053  // Explicitly add the mapping for the first block argument (the op being
1054  // matched).
1055  matchInputMapping.emplace_back();
1056  transform::detail::prepareValueMappings(matchInputMapping,
1057  getForwardedInputs(), state);
1058  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1059  actionResultMapping.resize(getForwardedOutputs().size());
1060 
1061  for (Operation *root : state.getPayloadOps(getRoot())) {
1062  WalkResult walkResult = root->walk([&](Operation *op) {
1063  // If getRestrictRoot is not present, skip over the root op itself so we
1064  // don't invalidate it.
1065  if (!getRestrictRoot() && op == root)
1066  return WalkResult::advance();
1067 
1068  DEBUG_MATCHER({
1069  DBGS_MATCHER() << "matching ";
1070  op->print(llvm::dbgs(),
1071  OpPrintingFlags().assumeVerified().skipRegions());
1072  llvm::dbgs() << " @" << op << "\n";
1073  });
1074 
// Only the first matcher argument (the op being matched) changes between
// visited ops; the forwarded inputs were mapped once above.
1075  firstMatchArgument.clear();
1076  firstMatchArgument.push_back(op);
1077 
1078  // Try all the match/action pairs until the first successful match.
1079  for (auto [matcher, action] : matchActionPairs) {
1081  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1082  state, matchOutputMapping);
// A definite failure aborts the whole walk.
1083  if (diag.isDefiniteFailure())
1084  return WalkResult::interrupt();
1085  if (diag.isSilenceableFailure()) {
1086  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1087  << " failed: " << diag.getMessage());
1088  continue;
1089  }
1090 
// Bind the matcher's yielded payload to the action's block arguments.
1091  auto scope = state.make_region_scope(action.getFunctionBody());
1092  if (failed(state.mapBlockArguments(
1093  action.getFunctionBody().front().getArguments(),
1094  matchOutputMapping))) {
1095  return WalkResult::interrupt();
1096  }
1097 
1098  for (Operation &transform :
1099  action.getFunctionBody().front().without_terminator()) {
1101  state.applyTransform(cast<TransformOpInterface>(transform));
1102  if (result.isDefiniteFailure())
1103  return WalkResult::interrupt();
1104  if (result.isSilenceableFailure()) {
1105  if (overallDiag.succeeded()) {
1106  overallDiag = emitSilenceableError() << "actions failed";
1107  }
1108  overallDiag.attachNote(action->getLoc())
1109  << "failed action: " << result.getMessage();
1110  overallDiag.attachNote(op->getLoc())
1111  << "when applied to this matching payload";
// The failure is now recorded in `overallDiag`: silence it and keep
// applying the remaining transforms of this action.
1112  (void)result.silence();
1113  continue;
1114  }
1115  }
// Accumulate the action's yielded payload into the forwarded outputs.
1116  if (failed(detail::appendValueMappings(
1117  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1118  action.getFunctionBody().front().getTerminator()->getOperands(),
1119  state, getFlattenResults()))) {
1121  << "action @" << action.getName()
1122  << " has results associated with multiple payload entities, "
1123  "but flattening was not requested";
1124  return WalkResult::interrupt();
1125  }
// Only the first successful match/action pair applies to this op.
1126  break;
1127  }
1128  return WalkResult::advance();
1129  });
1130  if (walkResult.wasInterrupted())
1132  }
1133 
1134  // The root operation should not have been affected, so we can just reassign
1135  // the payload to the result. Note that we need to consume the root handle to
1136  // make sure any handles to operations inside, that could have been affected
1137  // by actions, are invalidated.
1138  results.set(llvm::cast<OpResult>(getUpdated()),
1139  state.getPayloadOps(getRoot()));
1140  for (auto &&[result, mapping] :
1141  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1142  results.setMappedValues(result, mapping);
1143  }
1144  return overallDiag;
1145 }
1146 
1147 void transform::ForeachMatchOp::getAsmResultNames(
1148  OpAsmSetValueNameFn setNameFn) {
1149  setNameFn(getUpdated(), "updated_root");
1150  for (Value v : getForwardedOutputs()) {
1151  setNameFn(v, "yielded");
1152  }
1153 }
1154 
1155 void transform::ForeachMatchOp::getEffects(
1156  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1157  // Bail if invalid.
1158  if (getOperation()->getNumOperands() < 1 ||
1159  getOperation()->getNumResults() < 1) {
1160  return modifiesPayload(effects);
1161  }
1162 
1163  consumesHandle(getRootMutable(), effects);
1164  onlyReadsHandle(getForwardedInputsMutable(), effects);
1165  producesHandle(getOperation()->getOpResults(), effects);
1166  modifiesPayload(effects);
1167 }
1168 
1169 /// Parses the comma-separated list of symbol reference pairs of the format
1170 /// `@matcher -> @action`.
1171 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1172  ArrayAttr &matchers,
1173  ArrayAttr &actions) {
1174  StringAttr matcher;
1175  StringAttr action;
1176  SmallVector<Attribute> matcherList;
1177  SmallVector<Attribute> actionList;
1178  do {
1179  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1180  parser.parseSymbolName(action)) {
1181  return failure();
1182  }
1183  matcherList.push_back(SymbolRefAttr::get(matcher));
1184  actionList.push_back(SymbolRefAttr::get(action));
1185  } while (parser.parseOptionalComma().succeeded());
1186 
1187  matchers = parser.getBuilder().getArrayAttr(matcherList);
1188  actions = parser.getBuilder().getArrayAttr(actionList);
1189  return success();
1190 }
1191 
1192 /// Prints the comma-separated list of symbol reference pairs of the format
1193 /// `@matcher -> @action`.
1195  ArrayAttr matchers, ArrayAttr actions) {
1196  printer.increaseIndent();
1197  printer.increaseIndent();
1198  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1199  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1200  printer.printNewline();
1201  printer << cast<SymbolRefAttr>(matcher) << " -> "
1202  << cast<SymbolRefAttr>(action);
1203  if (idx != matchers.size() - 1)
1204  printer << ", ";
1205  }
1206  printer.decreaseIndent();
1207  printer.decreaseIndent();
1208 }
1209 
1210 LogicalResult transform::ForeachMatchOp::verify() {
1211  if (getMatchers().size() != getActions().size())
1212  return emitOpError() << "expected the same number of matchers and actions";
1213  if (getMatchers().empty())
1214  return emitOpError() << "expected at least one match/action pair";
1215 
1216  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1217  for (Attribute name : getMatchers()) {
1218  if (matcherNames.insert(name).second)
1219  continue;
1220  emitWarning() << "matcher " << name
1221  << " is used more than once, only the first match will apply";
1222  }
1223 
1224  return success();
1225 }
1226 
1227 /// Checks that the argument attributes of the function-like transform op `op`
1228 /// carry consistent consumption annotations. An argument may never be both
1229 /// consumed and readonly. If `op` is external, or `alsoVerifyInternal` is set,
/// every argument must be annotated as either consumed or readonly. For ops
/// with a body, an argument consumed in the body but annotated readonly is an
/// error and, if `emitWarnings` is set, an argument annotated as consumed but
/// not consumed in the body triggers a warning. Errors are returned as
/// silenceable errors emitted through `op`'s TransformOpInterface.
1231 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1232  bool alsoVerifyInternal = false) {
1233  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1234  llvm::SmallDenseSet<unsigned> consumedArguments;
1235  if (!op.isExternal()) {
1236  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1237  consumedArguments);
1238  }
1239  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1240  bool isConsumed =
1241  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1242  nullptr;
1243  bool isReadOnly =
1244  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1245  nullptr;
1246  if (isConsumed && isReadOnly) {
1247  return transformOp.emitSilenceableError()
1248  << "argument #" << i << " cannot be both readonly and consumed";
1249  }
1250  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1251  return transformOp.emitSilenceableError()
1252  << "must provide consumed/readonly status for arguments of "
1253  "external or called ops";
1254  }
// External ops have no body to compare the annotations against.
1255  if (op.isExternal())
1256  continue;
1257 
// NOTE(review): this only fires when the argument is annotated readonly. An
// argument consumed in the body with no annotation at all is accepted here
// unless the external / `alsoVerifyInternal` check above already rejected it;
// confirm that is intended.
1258  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1259  return transformOp.emitSilenceableError()
1260  << "argument #" << i
1261  << " is consumed in the body but is not marked as such";
1262  }
1263  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1264  // Cannot use op.emitWarning() here as it would attempt to verify the op
1265  // before printing, resulting in infinite recursion.
1266  emitWarning(op->getLoc())
1267  << "op argument #" << i
1268  << " is not consumed in the body but is marked as consumed";
1269  }
1270  }
1272 }
1273 
1274 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1275  SymbolTableCollection &symbolTable) {
1276  assert(getMatchers().size() == getActions().size());
1277  auto consumedAttr =
1278  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1279  for (auto &&[matcher, action] :
1280  llvm::zip_equal(getMatchers(), getActions())) {
1281  // Presence and typing.
1282  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1283  symbolTable.lookupNearestSymbolFrom(getOperation(),
1284  cast<SymbolRefAttr>(matcher)));
1285  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1286  symbolTable.lookupNearestSymbolFrom(getOperation(),
1287  cast<SymbolRefAttr>(action)));
1288  if (!matcherSymbol ||
1289  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1290  return emitError() << "unresolved matcher symbol " << matcher;
1291  if (!actionSymbol ||
1292  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1293  return emitError() << "unresolved action symbol " << action;
1294 
1295  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1296  /*emitWarnings=*/false,
1297  /*alsoVerifyInternal=*/true)
1298  .checkAndReport())) {
1299  return failure();
1300  }
1301  if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1302  /*emitWarnings=*/false,
1303  /*alsoVerifyInternal=*/true)
1304  .checkAndReport())) {
1305  return failure();
1306  }
1307 
1308  // Input -> matcher forwarding.
1309  TypeRange operandTypes = getOperandTypes();
1310  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1311  if (operandTypes.size() != matcherArguments.size()) {
1313  emitError() << "the number of operands (" << operandTypes.size()
1314  << ") doesn't match the number of matcher arguments ("
1315  << matcherArguments.size() << ") for " << matcher;
1316  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1317  return diag;
1318  }
1319  for (auto &&[i, operand, argument] :
1320  llvm::enumerate(operandTypes, matcherArguments)) {
1321  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1323  emitOpError()
1324  << "does not expect matcher symbol to consume its operand #" << i;
1325  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1326  return diag;
1327  }
1328 
1329  if (implementSameTransformInterface(operand, argument))
1330  continue;
1331 
1333  emitError()
1334  << "mismatching type interfaces for operand and matcher argument #"
1335  << i << " of matcher " << matcher;
1336  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1337  return diag;
1338  }
1339 
1340  // Matcher -> action forwarding.
1341  TypeRange matcherResults = matcherSymbol.getResultTypes();
1342  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1343  if (matcherResults.size() != actionArguments.size()) {
1344  return emitError() << "mismatching number of matcher results and "
1345  "action arguments between "
1346  << matcher << " (" << matcherResults.size() << ") and "
1347  << action << " (" << actionArguments.size() << ")";
1348  }
1349  for (auto &&[i, matcherType, actionType] :
1350  llvm::enumerate(matcherResults, actionArguments)) {
1351  if (implementSameTransformInterface(matcherType, actionType))
1352  continue;
1353 
1354  return emitError() << "mismatching type interfaces for matcher result "
1355  "and action argument #"
1356  << i << "of matcher " << matcher << " and action "
1357  << action;
1358  }
1359 
1360  // Action -> result forwarding.
1361  TypeRange actionResults = actionSymbol.getResultTypes();
1362  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1363  if (actionResults.size() != resultTypes.size()) {
1365  emitError() << "the number of action results ("
1366  << actionResults.size() << ") for " << action
1367  << " doesn't match the number of extra op results ("
1368  << resultTypes.size() << ")";
1369  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1370  return diag;
1371  }
1372  for (auto &&[i, resultType, actionType] :
1373  llvm::enumerate(resultTypes, actionResults)) {
1374  if (implementSameTransformInterface(resultType, actionType))
1375  continue;
1376 
1378  emitError() << "mismatching type interfaces for action result #" << i
1379  << " of action " << action << " and op result";
1380  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1381  return diag;
1382  }
1383  }
1384  return success();
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // ForeachOp
1389 //===----------------------------------------------------------------------===//
1390 
1392 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1393  transform::TransformResults &results,
1394  transform::TransformState &state) {
1395  // We store the payloads before executing the body as ops may be removed from
1396  // the mapping by the TrackingRewriter while iteration is in progress.
1398  detail::prepareValueMappings(payloads, getTargets(), state);
1399  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1400  bool withZipShortest = getWithZipShortest();
1401 
1402  // In case of `zip_shortest`, set the number of iterations to the
1403  // smallest payload in the targets.
1404  if (withZipShortest) {
1405  numIterations =
1406  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1407  const SmallVector<MappedValue> &B) {
1408  return A.size() < B.size();
1409  })->size();
1410 
1411  for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1412  payloads[argIdx].resize(numIterations);
1413  }
1414 
1415  // As we will be "zipping" over them, check all payloads have the same size.
1416  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1417  // when true.
1418  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1419  argIdx++) {
1420  if (payloads[argIdx].size() != numIterations) {
1421  return emitSilenceableError()
1422  << "prior targets' payload size (" << numIterations
1423  << ") differs from payload size (" << payloads[argIdx].size()
1424  << ") of target " << getTargets()[argIdx];
1425  }
1426  }
1427 
1428  // Start iterating, indexing into payloads to obtain the right arguments to
1429  // call the body with - each slice of payloads at the same argument index
1430  // corresponding to a tuple to use as the body's block arguments.
1431  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1432  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1433  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1434  auto scope = state.make_region_scope(getBody());
1435  // Set up arguments to the region's block.
1436  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1437  MappedValue argument = payloads[argIdx][iterIdx];
1438  // Note that each blockArg's handle gets associated with just a single
1439  // element from the corresponding target's payload.
1440  if (failed(state.mapBlockArgument(blockArg, {argument})))
1442  }
1443 
1444  // Execute loop body.
1445  for (Operation &transform : getBody().front().without_terminator()) {
1446  DiagnosedSilenceableFailure result = state.applyTransform(
1447  llvm::cast<transform::TransformOpInterface>(transform));
1448  if (!result.succeeded())
1449  return result;
1450  }
1451 
1452  // Append yielded payloads to corresponding results from prior iterations.
1453  OperandRange yieldOperands = getYieldOp().getOperands();
1454  for (auto &&[result, yieldOperand, resTuple] :
1455  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1456  // NB: each iteration we add any number of ops/vals/params to a result.
1457  if (isa<TransformHandleTypeInterface>(result.getType()))
1458  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1459  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1460  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1461  else if (isa<TransformParamTypeInterface>(result.getType()))
1462  llvm::append_range(resTuple, state.getParams(yieldOperand));
1463  else
1464  assert(false && "unhandled handle type");
1465  }
1466 
1467  // Associate the accumulated result payloads to the op's actual results.
1468  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1469  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1470 
1472 }
1473 
1474 void transform::ForeachOp::getEffects(
1475  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1476  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1477  // arity errors, this method might get called before/in absence of `verify()`.
1478  for (auto &&[target, blockArg] :
1479  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1480  BlockArgument blockArgument = blockArg;
1481  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1482  return isHandleConsumed(blockArgument,
1483  cast<TransformOpInterface>(&op));
1484  })) {
1485  consumesHandle(target, effects);
1486  } else {
1487  onlyReadsHandle(target, effects);
1488  }
1489  }
1490 
1491  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1492  return doesModifyPayload(cast<TransformOpInterface>(&op));
1493  })) {
1494  modifiesPayload(effects);
1495  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1496  return doesReadPayload(cast<TransformOpInterface>(&op));
1497  })) {
1498  onlyReadsPayload(effects);
1499  }
1500 
1501  producesHandle(getOperation()->getOpResults(), effects);
1502 }
1503 
1504 void transform::ForeachOp::getSuccessorRegions(
1505  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1506  Region *bodyRegion = &getBody();
1507  if (point.isParent()) {
1508  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1509  return;
1510  }
1511 
1512  // Branch back to the region or the parent.
1513  assert(point == getBody() && "unexpected region index");
1514  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1515  regions.emplace_back();
1516 }
1517 
/// Returns the values forwarded to the body's block arguments when entering
/// the body region: all operands of this op (NOTE(review): assumed to be
/// exactly `targets`; confirm). Only the body region is a valid `point`.
1519 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1520  // Each block argument handle is mapped to a subset (a single payload entity
1521  // to be precise: one op, value or param) of the payload of the corresponding
// `targets` operand of ForeachOp.
1522  assert(point == getBody() && "unexpected region index");
1523  return getOperation()->getOperands();
1524 }
1525 
1526 transform::YieldOp transform::ForeachOp::getYieldOp() {
1527  return cast<transform::YieldOp>(getBody().front().getTerminator());
1528 }
1529 
1530 LogicalResult transform::ForeachOp::verify() {
1531  for (auto [targetOpt, bodyArgOpt] :
1532  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1533  if (!targetOpt || !bodyArgOpt)
1534  return emitOpError() << "expects the same number of targets as the body "
1535  "has block arguments";
1536  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1537  return emitOpError(
1538  "expects co-indexed targets and the body's "
1539  "block arguments to have the same op/value/param type");
1540  }
1541 
1542  for (auto [resultOpt, yieldOperandOpt] :
1543  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1544  if (!resultOpt || !yieldOperandOpt)
1545  return emitOpError() << "expects the same number of results as the "
1546  "yield terminator has operands";
1547  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1548  return emitOpError("expects co-indexed results and yield "
1549  "operands to have the same op/value/param type");
1550  }
1551 
1552  return success();
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // GetParentOp
1557 //===----------------------------------------------------------------------===//
1558 
/// For each payload op associated with `target`, walks up its ancestors
/// `getNthParent()` times. Each step moves to the nearest strict ancestor that
/// passes the optional filters: op name equal to `getOpName()` when set, and an
/// isolated-from-above check when `getIsolatedFromAbove()` is set (presumably
/// requiring the IsIsolatedFromAbove trait; confirm). With `getDeduplicate()`,
/// each resulting parent is kept only once, in order of first occurrence. A
/// target without a suitable ancestor is a silenceable error, unless
/// `getAllowEmptyResults()` is set.
1560 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1561  transform::TransformResults &results,
1562  transform::TransformState &state) {
1563  SmallVector<Operation *> parents;
1564  DenseSet<Operation *> resultSet;
1565  for (Operation *target : state.getPayloadOps(getTarget())) {
1566  Operation *parent = target;
1567  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1568  parent = parent->getParentOp();
// Skip ancestors until one satisfies every requested filter.
1569  while (parent) {
1570  bool checkIsolatedFromAbove =
1571  !getIsolatedFromAbove() ||
1573  bool checkOpName = !getOpName().has_value() ||
1574  parent->getName().getStringRef() == *getOpName();
1575  if (checkIsolatedFromAbove && checkOpName)
1576  break;
1577  parent = parent->getParentOp();
1578  }
1579  if (!parent) {
1580  if (getAllowEmptyResults()) {
// NOTE(review): this sets the result to the parents collected so far for
// earlier targets, which is not necessarily empty; confirm that a partial
// result is intended when only some targets lack a suitable ancestor.
1581  results.set(llvm::cast<OpResult>(getResult()), parents);
1583  }
1585  emitSilenceableError()
1586  << "could not find a parent op that matches all requirements";
1587  diag.attachNote(target->getLoc()) << "target op";
1588  return diag;
1589  }
1590  }
1591  if (getDeduplicate()) {
1592  if (resultSet.insert(parent).second)
1593  parents.push_back(parent);
1594  } else {
1595  parents.push_back(parent);
1596  }
1597  }
1598  results.set(llvm::cast<OpResult>(getResult()), parents);
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // GetConsumersOfResult
1604 //===----------------------------------------------------------------------===//
1605 
/// Returns the users of result #`getResultNumber()` of the single payload op
/// associated with `target`. An empty handle sets an empty result. A handle
/// associated with several payload ops, or an out-of-range result number, is a
/// definite failure.
1607 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1608  transform::TransformResults &results,
1609  transform::TransformState &state) {
1610  int64_t resultNumber = getResultNumber();
1611  auto payloadOps = state.getPayloadOps(getTarget());
1612  if (std::empty(payloadOps)) {
1613  results.set(cast<OpResult>(getResult()), {});
1615  }
1616  if (!llvm::hasSingleElement(payloadOps))
1617  return emitDefiniteFailure()
1618  << "handle must be mapped to exactly one payload op";
1619 
1620  Operation *target = *payloadOps.begin();
// NOTE(review): `resultNumber` is signed, so a negative value would pass this
// bound check; confirm the attribute is constrained to be non-negative.
1621  if (target->getNumResults() <= resultNumber)
1622  return emitDefiniteFailure() << "result number overflow";
// NOTE(review): if `getUsers()` yields one entry per use, an op using this
// result several times is listed several times; confirm whether that is
// intended or deduplication is wanted.
1623  results.set(llvm::cast<OpResult>(getResult()),
1624  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1626 }
1627 
1628 //===----------------------------------------------------------------------===//
1629 // GetDefiningOp
1630 //===----------------------------------------------------------------------===//
1631 
/// Returns the defining op of every payload value associated with `target`, in
/// order and without deduplication (two results of the same op yield that op
/// twice). Block arguments have no defining op; they produce a silenceable
/// error with a note pointing at the offending value.
1633 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1634  transform::TransformResults &results,
1635  transform::TransformState &state) {
1636  SmallVector<Operation *> definingOps;
1637  for (Value v : state.getPayloadValues(getTarget())) {
1638  if (llvm::isa<BlockArgument>(v)) {
1640  emitSilenceableError() << "cannot get defining op of block argument";
1641  diag.attachNote(v.getLoc()) << "target value";
1642  return diag;
1643  }
1644  definingOps.push_back(v.getDefiningOp());
1645  }
1646  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1648 }
1649 
1650 //===----------------------------------------------------------------------===//
1651 // GetProducerOfOperand
1652 //===----------------------------------------------------------------------===//
1653 
/// Returns, for every payload op associated with `target`, the op defining its
/// operand #`getOperandNumber()`. A silenceable error, with a note on the
/// target op, is produced when that operand does not exist or has no defining
/// op (e.g. it is a block argument).
1655 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1656  transform::TransformResults &results,
1657  transform::TransformState &state) {
1658  int64_t operandNumber = getOperandNumber();
1659  SmallVector<Operation *> producers;
1660  for (Operation *target : state.getPayloadOps(getTarget())) {
// NOTE(review): `operandNumber` is signed, so a negative value would pass this
// bound check; confirm the attribute is constrained to be non-negative.
1661  Operation *producer =
1662  target->getNumOperands() <= operandNumber
1663  ? nullptr
1664  : target->getOperand(operandNumber).getDefiningOp();
1665  if (!producer) {
1667  emitSilenceableError()
1668  << "could not find a producer for operand number: " << operandNumber
1669  << " of " << *target;
1670  diag.attachNote(target->getLoc()) << "target op";
1671  return diag;
1672  }
1673  producers.push_back(producer);
1674  }
1675  results.set(llvm::cast<OpResult>(getResult()), producers);
1677 }
1678 
1679 //===----------------------------------------------------------------------===//
1680 // GetOperandOp
1681 //===----------------------------------------------------------------------===//
1682 
/// Collects, for every payload op associated with `target`, the operands at
/// the positions selected by the raw position list and the `getIsAll()` /
/// `getIsInverted()` flags. Results are concatenated in payload-op order. A
/// silenceable failure while selecting positions for some payload op is
/// returned with a note pointing at that op.
1684 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1685  transform::TransformResults &results,
1686  transform::TransformState &state) {
1687  SmallVector<Value> operands;
1688  for (Operation *target : state.getPayloadOps(getTarget())) {
1689  SmallVector<int64_t> operandPositions;
1691  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1692  target->getNumOperands(), operandPositions);
// NOTE(review): only silenceable failures are propagated here; confirm the
// position expansion cannot produce a definite failure.
1693  if (diag.isSilenceableFailure()) {
1694  diag.attachNote(target->getLoc())
1695  << "while considering positions of this payload operation";
1696  return diag;
1697  }
1698  llvm::append_range(operands,
1699  llvm::map_range(operandPositions, [&](int64_t pos) {
1700  return target->getOperand(pos);
1701  }));
1702  }
1703  results.setValues(cast<OpResult>(getResult()), operands);
1705 }
1706 
1707 LogicalResult transform::GetOperandOp::verify() {
1708  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1709  getIsInverted(), getIsAll());
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // GetResultOp
1714 //===----------------------------------------------------------------------===//
1715 
1717 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1718  transform::TransformResults &results,
1719  transform::TransformState &state) {
1720  SmallVector<Value> opResults;
1721  for (Operation *target : state.getPayloadOps(getTarget())) {
1722  SmallVector<int64_t> resultPositions;
1724  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1725  target->getNumResults(), resultPositions);
1726  if (diag.isSilenceableFailure()) {
1727  diag.attachNote(target->getLoc())
1728  << "while considering positions of this payload operation";
1729  return diag;
1730  }
1731  llvm::append_range(opResults,
1732  llvm::map_range(resultPositions, [&](int64_t pos) {
1733  return target->getResult(pos);
1734  }));
1735  }
1736  results.setValues(cast<OpResult>(getResult()), opResults);
1738 }
1739 
1740 LogicalResult transform::GetResultOp::verify() {
1741  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1742  getIsInverted(), getIsAll());
1743 }
1744 
1745 //===----------------------------------------------------------------------===//
1746 // GetTypeOp
1747 //===----------------------------------------------------------------------===//
1748 
1749 void transform::GetTypeOp::getEffects(
1750  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1751  onlyReadsHandle(getValueMutable(), effects);
1752  producesHandle(getOperation()->getOpResults(), effects);
1753  onlyReadsPayload(effects);
1754 }
1755 
1757 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1758  transform::TransformResults &results,
1759  transform::TransformState &state) {
1760  SmallVector<Attribute> params;
1761  for (Value value : state.getPayloadValues(getValue())) {
1762  Type type = value.getType();
1763  if (getElemental()) {
1764  if (auto shaped = dyn_cast<ShapedType>(type)) {
1765  type = shaped.getElementType();
1766  }
1767  }
1768  params.push_back(TypeAttr::get(type));
1769  }
1770  results.setParams(cast<OpResult>(getResult()), params);
1772 }
1773 
1774 //===----------------------------------------------------------------------===//
1775 // IncludeOp
1776 //===----------------------------------------------------------------------===//
1777 
1778 /// Applies the transform ops contained in `block`. Maps `results` to the same
1779 /// values as the operands of the block terminator.
1781 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1783  transform::TransformResults &results) {
1784  // Apply the sequenced ops one by one.
1785  for (Operation &transform : block.without_terminator()) {
1787  state.applyTransform(cast<transform::TransformOpInterface>(transform));
1788  if (result.isDefiniteFailure())
1789  return result;
1790 
1791  if (result.isSilenceableFailure()) {
1792  if (mode == transform::FailurePropagationMode::Propagate) {
1793  // Propagate empty results in case of early exit.
1794  forwardEmptyOperands(&block, state, results);
1795  return result;
1796  }
1797  (void)result.silence();
1798  }
1799  }
1800 
1801  // Forward the operation mapping for values yielded from the sequence to the
1802  // values produced by the sequence op.
1803  transform::detail::forwardTerminatorOperands(&block, state, results);
1805 }
1806 
1808 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1809  transform::TransformResults &results,
1810  transform::TransformState &state) {
1811  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1812  getOperation(), getTarget());
1813  assert(callee && "unverified reference to unknown symbol");
1814 
1815  if (callee.isExternal())
1816  return emitDefiniteFailure() << "unresolved external named sequence";
1817 
1818  // Map operands to block arguments.
1820  detail::prepareValueMappings(mappings, getOperands(), state);
1821  auto scope = state.make_region_scope(callee.getBody());
1822  for (auto &&[arg, map] :
1823  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1824  if (failed(state.mapBlockArgument(arg, map)))
1826  }
1827 
1829  callee.getBody().front(), getFailurePropagationMode(), state, results);
1830  mappings.clear();
1832  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1833  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1834  results.setMappedValues(result, mapping);
1835  return result;
1836 }
1837 
1839 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1840 
1841 void transform::IncludeOp::getEffects(
1842  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1843  // Always mark as modifying the payload.
1844  // TODO: a mechanism to annotate effects on payload. Even when all handles are
1845  // only read, the payload may still be modified, so we currently stay on the
1846  // conservative side and always indicate modification. This may prevent some
1847  // code reordering.
1848  modifiesPayload(effects);
1849 
1850  // Results are always produced.
1851  producesHandle(getOperation()->getOpResults(), effects);
1852 
1853  // Adds default effects to operands and results. This will be added if
1854  // preconditions fail so the trait verifier doesn't complain about missing
1855  // effects and the real precondition failure is reported later on.
1856  auto defaultEffects = [&] {
1857  onlyReadsHandle(getOperation()->getOpOperands(), effects);
1858  };
1859 
1860  // Bail if the callee is unknown. This may run as part of the verification
1861  // process before we verified the validity of the callee or of this op.
1862  auto target =
1863  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1864  if (!target)
1865  return defaultEffects();
1866  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1867  getOperation(), getTarget());
1868  if (!callee)
1869  return defaultEffects();
1870  DiagnosedSilenceableFailure earlyVerifierResult =
1871  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1872  if (!earlyVerifierResult.succeeded()) {
1873  (void)earlyVerifierResult.silence();
1874  return defaultEffects();
1875  }
1876 
1877  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1878  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1879  consumesHandle(getOperation()->getOpOperand(i), effects);
1880  else
1881  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
1882  }
1883 }
1884 
1885 LogicalResult
1886 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1887  // Access through indirection and do additional checking because this may be
1888  // running before the main op verifier.
1889  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1890  if (!targetAttr)
1891  return emitOpError() << "expects a 'target' symbol reference attribute";
1892 
1893  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1894  *this, targetAttr);
1895  if (!target)
1896  return emitOpError() << "does not reference a named transform sequence";
1897 
1898  FunctionType fnType = target.getFunctionType();
1899  if (fnType.getNumInputs() != getNumOperands())
1900  return emitError("incorrect number of operands for callee");
1901 
1902  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1903  if (getOperand(i).getType() != fnType.getInput(i)) {
1904  return emitOpError("operand type mismatch: expected operand type ")
1905  << fnType.getInput(i) << ", but provided "
1906  << getOperand(i).getType() << " for operand number " << i;
1907  }
1908  }
1909 
1910  if (fnType.getNumResults() != getNumResults())
1911  return emitError("incorrect number of results for callee");
1912 
1913  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1914  Type resultType = getResult(i).getType();
1915  Type funcType = fnType.getResult(i);
1916  if (!implementSameTransformInterface(resultType, funcType)) {
1917  return emitOpError() << "type of result #" << i
1918  << " must implement the same transform dialect "
1919  "interface as the corresponding callee result";
1920  }
1921  }
1922 
1924  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1925  /*alsoVerifyInternal=*/true)
1926  .checkAndReport();
1927 }
1928 
1929 //===----------------------------------------------------------------------===//
1930 // MatchOperationEmptyOp
1931 //===----------------------------------------------------------------------===//
1932 
1933 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1934  ::std::optional<::mlir::Operation *> maybeCurrent,
1936  if (!maybeCurrent.has_value()) {
1937  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1939  }
1940  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1941  return emitSilenceableError() << "operation is not empty";
1942 }
1943 
1944 //===----------------------------------------------------------------------===//
1945 // MatchOperationNameOp
1946 //===----------------------------------------------------------------------===//
1947 
1948 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1949  Operation *current, transform::TransformResults &results,
1950  transform::TransformState &state) {
1951  StringRef currentOpName = current->getName().getStringRef();
1952  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1953  if (acceptedAttr.getValue() == currentOpName)
1955  }
1956  return emitSilenceableError() << "wrong operation name";
1957 }
1958 
1959 //===----------------------------------------------------------------------===//
1960 // MatchParamCmpIOp
1961 //===----------------------------------------------------------------------===//
1962 
1964 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1965  transform::TransformResults &results,
1966  transform::TransformState &state) {
1967  auto signedAPIntAsString = [&](const APInt &value) {
1968  std::string str;
1969  llvm::raw_string_ostream os(str);
1970  value.print(os, /*isSigned=*/true);
1971  return str;
1972  };
1973 
1974  ArrayRef<Attribute> params = state.getParams(getParam());
1975  ArrayRef<Attribute> references = state.getParams(getReference());
1976 
1977  if (params.size() != references.size()) {
1978  return emitSilenceableError()
1979  << "parameters have different payload lengths (" << params.size()
1980  << " vs " << references.size() << ")";
1981  }
1982 
1983  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1984  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1985  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1986  if (!intAttr || !refAttr) {
1987  return emitDefiniteFailure()
1988  << "non-integer parameter value not expected";
1989  }
1990  if (intAttr.getType() != refAttr.getType()) {
1991  return emitDefiniteFailure()
1992  << "mismatching integer attribute types in parameter #" << i;
1993  }
1994  APInt value = intAttr.getValue();
1995  APInt refValue = refAttr.getValue();
1996 
1997  // TODO: this copy will not be necessary in C++20.
1998  int64_t position = i;
1999  auto reportError = [&](StringRef direction) {
2001  emitSilenceableError() << "expected parameter to be " << direction
2002  << " " << signedAPIntAsString(refValue)
2003  << ", got " << signedAPIntAsString(value);
2004  diag.attachNote(getParam().getLoc())
2005  << "value # " << position
2006  << " associated with the parameter defined here";
2007  return diag;
2008  };
2009 
2010  switch (getPredicate()) {
2011  case MatchCmpIPredicate::eq:
2012  if (value.eq(refValue))
2013  break;
2014  return reportError("equal to");
2015  case MatchCmpIPredicate::ne:
2016  if (value.ne(refValue))
2017  break;
2018  return reportError("not equal to");
2019  case MatchCmpIPredicate::lt:
2020  if (value.slt(refValue))
2021  break;
2022  return reportError("less than");
2023  case MatchCmpIPredicate::le:
2024  if (value.sle(refValue))
2025  break;
2026  return reportError("less than or equal to");
2027  case MatchCmpIPredicate::gt:
2028  if (value.sgt(refValue))
2029  break;
2030  return reportError("greater than");
2031  case MatchCmpIPredicate::ge:
2032  if (value.sge(refValue))
2033  break;
2034  return reportError("greater than or equal to");
2035  }
2036  }
2038 }
2039 
2040 void transform::MatchParamCmpIOp::getEffects(
2041  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2042  onlyReadsHandle(getParamMutable(), effects);
2043  onlyReadsHandle(getReferenceMutable(), effects);
2044 }
2045 
2046 //===----------------------------------------------------------------------===//
2047 // ParamConstantOp
2048 //===----------------------------------------------------------------------===//
2049 
2051 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2052  transform::TransformResults &results,
2053  transform::TransformState &state) {
2054  results.setParams(cast<OpResult>(getParam()), {getValue()});
2056 }
2057 
2058 //===----------------------------------------------------------------------===//
2059 // MergeHandlesOp
2060 //===----------------------------------------------------------------------===//
2061 
2063 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2064  transform::TransformResults &results,
2065  transform::TransformState &state) {
2066  ValueRange handles = getHandles();
2067  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2068  SmallVector<Operation *> operations;
2069  for (Value operand : handles)
2070  llvm::append_range(operations, state.getPayloadOps(operand));
2071  if (!getDeduplicate()) {
2072  results.set(llvm::cast<OpResult>(getResult()), operations);
2074  }
2075 
2076  SetVector<Operation *> uniqued(operations.begin(), operations.end());
2077  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2079  }
2080 
2081  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2082  SmallVector<Attribute> attrs;
2083  for (Value attribute : handles)
2084  llvm::append_range(attrs, state.getParams(attribute));
2085  if (!getDeduplicate()) {
2086  results.setParams(cast<OpResult>(getResult()), attrs);
2088  }
2089 
2090  SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
2091  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2093  }
2094 
2095  assert(
2096  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2097  "expected value handle type");
2098  SmallVector<Value> payloadValues;
2099  for (Value value : handles)
2100  llvm::append_range(payloadValues, state.getPayloadValues(value));
2101  if (!getDeduplicate()) {
2102  results.setValues(cast<OpResult>(getResult()), payloadValues);
2104  }
2105 
2106  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
2107  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2109 }
2110 
2111 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2112  // Handles may be the same if deduplicating is enabled.
2113  return getDeduplicate();
2114 }
2115 
2116 void transform::MergeHandlesOp::getEffects(
2117  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2118  onlyReadsHandle(getHandlesMutable(), effects);
2119  producesHandle(getOperation()->getOpResults(), effects);
2120 
2121  // There are no effects on the Payload IR as this is only a handle
2122  // manipulation.
2123 }
2124 
2125 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2126  if (getDeduplicate() || getHandles().size() != 1)
2127  return {};
2128 
2129  // If deduplication is not required and there is only one operand, it can be
2130  // used directly instead of merging.
2131  return getHandles().front();
2132 }
2133 
2134 //===----------------------------------------------------------------------===//
2135 // NamedSequenceOp
2136 //===----------------------------------------------------------------------===//
2137 
2139 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2140  transform::TransformResults &results,
2141  transform::TransformState &state) {
2142  if (isExternal())
2143  return emitDefiniteFailure() << "unresolved external named sequence";
2144 
2145  // Map the entry block argument to the list of operations.
2146  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2147  // without attaching the interface / trait since that is tailored to a
2148  // dangling top-level op that does not get "called".
2149  auto scope = state.make_region_scope(getBody());
2151  state, this->getOperation(), getBody())))
2153 
2154  return applySequenceBlock(getBody().front(),
2155  FailurePropagationMode::Propagate, state, results);
2156 }
2157 
2158 void transform::NamedSequenceOp::getEffects(
2159  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2160 
2162  OperationState &result) {
2164  parser, result, /*allowVariadic=*/false,
2165  getFunctionTypeAttrName(result.name),
2166  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2168  std::string &) { return builder.getFunctionType(inputs, results); },
2169  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2170 }
2171 
2174  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2175  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2176  getResAttrsAttrName());
2177 }
2178 
2179 /// Verifies that a symbol function-like transform dialect operation has the
2180 /// signature and the terminator that have conforming types, i.e., types
2181 /// implementing the same transform dialect type interface. If `allowExternal`
2182 /// is set, allow external symbols (declarations) and don't check the terminator
2183 /// as it may not exist.
2185 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2186  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2189  << "cannot be defined inside another transform op";
2190  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2191  return diag;
2192  }
2193 
2194  if (op.isExternal() || op.getFunctionBody().empty()) {
2195  if (allowExternal)
2197 
2198  return emitSilenceableFailure(op) << "cannot be external";
2199  }
2200 
2201  if (op.getFunctionBody().front().empty())
2202  return emitSilenceableFailure(op) << "expected a non-empty body block";
2203 
2204  Operation *terminator = &op.getFunctionBody().front().back();
2205  if (!isa<transform::YieldOp>(terminator)) {
2207  << "expected '"
2208  << transform::YieldOp::getOperationName()
2209  << "' as terminator";
2210  diag.attachNote(terminator->getLoc()) << "terminator";
2211  return diag;
2212  }
2213 
2214  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2215  return emitSilenceableFailure(terminator)
2216  << "expected terminator to have as many operands as the parent op "
2217  "has results";
2218  }
2219  for (auto [i, operandType, resultType] : llvm::zip_equal(
2220  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2221  terminator->getOperands().getType(), op.getResultTypes())) {
2222  if (operandType == resultType)
2223  continue;
2224  return emitSilenceableFailure(terminator)
2225  << "the type of the terminator operand #" << i
2226  << " must match the type of the corresponding parent op result ("
2227  << operandType << " vs " << resultType << ")";
2228  }
2229 
2231 }
2232 
2233 /// Verification of a NamedSequenceOp. This does not report the error
2234 /// immediately, so it can be used to check for op's well-formedness before the
2235 /// verifier runs, e.g., during trait verification.
2237 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2238  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2239  if (!parent->getAttr(
2240  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2243  << "expects the parent symbol table to have the '"
2244  << transform::TransformDialect::kWithNamedSequenceAttrName
2245  << "' attribute";
2246  diag.attachNote(parent->getLoc()) << "symbol table operation";
2247  return diag;
2248  }
2249  }
2250 
2251  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2254  << "cannot be defined inside another transform op";
2255  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2256  return diag;
2257  }
2258 
2259  if (op.isExternal() || op.getBody().empty())
2260  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2261  emitWarnings);
2262 
2263  if (op.getBody().front().empty())
2264  return emitSilenceableFailure(op) << "expected a non-empty body block";
2265 
2266  Operation *terminator = &op.getBody().front().back();
2267  if (!isa<transform::YieldOp>(terminator)) {
2269  << "expected '"
2270  << transform::YieldOp::getOperationName()
2271  << "' as terminator";
2272  diag.attachNote(terminator->getLoc()) << "terminator";
2273  return diag;
2274  }
2275 
2276  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2277  return emitSilenceableFailure(terminator)
2278  << "expected terminator to have as many operands as the parent op "
2279  "has results";
2280  }
2281  for (auto [i, operandType, resultType] :
2282  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2283  terminator->getOperands().getType(),
2284  op.getFunctionType().getResults())) {
2285  if (operandType == resultType)
2286  continue;
2287  return emitSilenceableFailure(terminator)
2288  << "the type of the terminator operand #" << i
2289  << " must match the type of the corresponding parent op result ("
2290  << operandType << " vs " << resultType << ")";
2291  }
2292 
2293  auto funcOp = cast<FunctionOpInterface>(*op);
2295  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2296  if (!diag.succeeded())
2297  return diag;
2298 
2299  return verifyYieldingSingleBlockOp(funcOp,
2300  /*allowExternal=*/true);
2301 }
2302 
2303 LogicalResult transform::NamedSequenceOp::verify() {
2304  // Actual verification happens in a separate function for reusability.
2305  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2306 }
2307 
2308 template <typename FnTy>
2309 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2310  Type bbArgType, TypeRange extraBindingTypes,
2311  FnTy bodyBuilder) {
2312  SmallVector<Type> types;
2313  types.reserve(1 + extraBindingTypes.size());
2314  types.push_back(bbArgType);
2315  llvm::append_range(types, extraBindingTypes);
2316 
2317  OpBuilder::InsertionGuard guard(builder);
2318  Region *region = state.regions.back().get();
2319  Block *bodyBlock =
2320  builder.createBlock(region, region->begin(), types,
2321  SmallVector<Location>(types.size(), state.location));
2322 
2323  // Populate body.
2324  builder.setInsertionPointToStart(bodyBlock);
2325  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2326  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2327  } else {
2328  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2329  bodyBlock->getArguments().drop_front());
2330  }
2331 }
2332 
2333 void transform::NamedSequenceOp::build(OpBuilder &builder,
2334  OperationState &state, StringRef symName,
2335  Type rootType, TypeRange resultTypes,
2336  SequenceBodyBuilderFn bodyBuilder,
2338  ArrayRef<DictionaryAttr> argAttrs) {
2339  state.addAttribute(SymbolTable::getSymbolAttrName(),
2340  builder.getStringAttr(symName));
2341  state.addAttribute(getFunctionTypeAttrName(state.name),
2343  rootType, resultTypes)));
2344  state.attributes.append(attrs.begin(), attrs.end());
2345  state.addRegion();
2346 
2347  buildSequenceBody(builder, state, rootType,
2348  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2349 }
2350 
2351 //===----------------------------------------------------------------------===//
2352 // NumAssociationsOp
2353 //===----------------------------------------------------------------------===//
2354 
2356 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2357  transform::TransformResults &results,
2358  transform::TransformState &state) {
2359  size_t numAssociations =
2361  .Case([&](TransformHandleTypeInterface opHandle) {
2362  return llvm::range_size(state.getPayloadOps(getHandle()));
2363  })
2364  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2365  return llvm::range_size(state.getPayloadValues(getHandle()));
2366  })
2367  .Case([&](TransformParamTypeInterface param) {
2368  return llvm::range_size(state.getParams(getHandle()));
2369  })
2370  .Default([](Type) {
2371  llvm_unreachable("unknown kind of transform dialect type");
2372  return 0;
2373  });
2374  results.setParams(cast<OpResult>(getNum()),
2375  rewriter.getI64IntegerAttr(numAssociations));
2377 }
2378 
2379 LogicalResult transform::NumAssociationsOp::verify() {
2380  // Verify that the result type accepts an i64 attribute as payload.
2381  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2382  return resultType
2383  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2384  .checkAndReport();
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // SelectOp
2389 //===----------------------------------------------------------------------===//
2390 
2392 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2393  transform::TransformResults &results,
2394  transform::TransformState &state) {
2395  SmallVector<Operation *> result;
2396  auto payloadOps = state.getPayloadOps(getTarget());
2397  for (Operation *op : payloadOps) {
2398  if (op->getName().getStringRef() == getOpName())
2399  result.push_back(op);
2400  }
2401  results.set(cast<OpResult>(getResult()), result);
2403 }
2404 
2405 //===----------------------------------------------------------------------===//
2406 // SplitHandleOp
2407 //===----------------------------------------------------------------------===//
2408 
2409 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2410  Value target, int64_t numResultHandles) {
2411  result.addOperands(target);
2412  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2413 }
2414 
2416 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2417  transform::TransformResults &results,
2418  transform::TransformState &state) {
2419  int64_t numPayloads =
2421  .Case<TransformHandleTypeInterface>([&](auto x) {
2422  return llvm::range_size(state.getPayloadOps(getHandle()));
2423  })
2424  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2425  return llvm::range_size(state.getPayloadValues(getHandle()));
2426  })
2427  .Case<TransformParamTypeInterface>([&](auto x) {
2428  return llvm::range_size(state.getParams(getHandle()));
2429  })
2430  .Default([](auto x) {
2431  llvm_unreachable("unknown transform dialect type interface");
2432  return -1;
2433  });
2434 
2435  auto produceNumOpsError = [&]() {
2436  return emitSilenceableError()
2437  << getHandle() << " expected to contain " << this->getNumResults()
2438  << " payloads but it contains " << numPayloads << " payloads";
2439  };
2440 
2441  // Fail if there are more payload ops than results and no overflow result was
2442  // specified.
2443  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2444  return produceNumOpsError();
2445 
2446  // Fail if there are more results than payload ops. Unless:
2447  // - "fail_on_payload_too_small" is set to "false", or
2448  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2449  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2450  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2451  return produceNumOpsError();
2452 
2453  // Distribute payloads.
2454  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2455  if (getOverflowResult())
2456  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2457 
2458  auto container = [&]() {
2459  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2460  return llvm::map_to_vector(
2461  state.getPayloadOps(getHandle()),
2462  [](Operation *op) -> MappedValue { return op; });
2463  }
2464  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2465  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2466  [](Value v) -> MappedValue { return v; });
2467  }
2468  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2469  "unsupported kind of transform dialect type");
2470  return llvm::map_to_vector(state.getParams(getHandle()),
2471  [](Attribute a) -> MappedValue { return a; });
2472  }();
2473 
2474  for (auto &&en : llvm::enumerate(container)) {
2475  int64_t resultNum = en.index();
2476  if (resultNum >= getNumResults())
2477  resultNum = *getOverflowResult();
2478  resultHandles[resultNum].push_back(en.value());
2479  }
2480 
2481  // Set transform op results.
2482  for (auto &&it : llvm::enumerate(resultHandles))
2483  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2484  it.value());
2485 
2487 }
2488 
2489 void transform::SplitHandleOp::getEffects(
2490  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2491  onlyReadsHandle(getHandleMutable(), effects);
2492  producesHandle(getOperation()->getOpResults(), effects);
2493  // There are no effects on the Payload IR as this is only a handle
2494  // manipulation.
2495 }
2496 
2497 LogicalResult transform::SplitHandleOp::verify() {
2498  if (getOverflowResult().has_value() &&
2499  !(*getOverflowResult() < getNumResults()))
2500  return emitOpError("overflow_result is not a valid result index");
2501 
2502  for (Type resultType : getResultTypes()) {
2503  if (implementSameTransformInterface(getHandle().getType(), resultType))
2504  continue;
2505 
2506  return emitOpError("expects result types to implement the same transform "
2507  "interface as the operand type");
2508  }
2509 
2510  return success();
2511 }
2512 
2513 //===----------------------------------------------------------------------===//
2514 // ReplicateOp
2515 //===----------------------------------------------------------------------===//
2516 
2518 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2519  transform::TransformResults &results,
2520  transform::TransformState &state) {
2521  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2522  for (const auto &en : llvm::enumerate(getHandles())) {
2523  Value handle = en.value();
2524  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2525  SmallVector<Operation *> current =
2526  llvm::to_vector(state.getPayloadOps(handle));
2527  SmallVector<Operation *> payload;
2528  payload.reserve(numRepetitions * current.size());
2529  for (unsigned i = 0; i < numRepetitions; ++i)
2530  llvm::append_range(payload, current);
2531  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2532  } else {
2533  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2534  "expected param type");
2535  ArrayRef<Attribute> current = state.getParams(handle);
2536  SmallVector<Attribute> params;
2537  params.reserve(numRepetitions * current.size());
2538  for (unsigned i = 0; i < numRepetitions; ++i)
2539  llvm::append_range(params, current);
2540  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2541  params);
2542  }
2543  }
2545 }
2546 
2547 void transform::ReplicateOp::getEffects(
2548  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2549  onlyReadsHandle(getPatternMutable(), effects);
2550  onlyReadsHandle(getHandlesMutable(), effects);
2551  producesHandle(getOperation()->getOpResults(), effects);
2552 }
2553 
2554 //===----------------------------------------------------------------------===//
2555 // SequenceOp
2556 //===----------------------------------------------------------------------===//
2557 
2559 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2560  transform::TransformResults &results,
2561  transform::TransformState &state) {
2562  // Map the entry block argument to the list of operations.
2563  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2564  if (failed(mapBlockArguments(state)))
2566 
2567  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2568  results);
2569 }
2570 
2571 static ParseResult parseSequenceOpOperands(
2572  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2573  Type &rootType,
2574  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2575  SmallVectorImpl<Type> &extraBindingTypes) {
2576  OpAsmParser::UnresolvedOperand rootOperand;
2577  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2578  if (!hasRoot.has_value()) {
2579  root = std::nullopt;
2580  return success();
2581  }
2582  if (failed(hasRoot.value()))
2583  return failure();
2584  root = rootOperand;
2585 
2586  if (succeeded(parser.parseOptionalComma())) {
2587  if (failed(parser.parseOperandList(extraBindings)))
2588  return failure();
2589  }
2590  if (failed(parser.parseColon()))
2591  return failure();
2592 
2593  // The paren is truly optional.
2594  (void)parser.parseOptionalLParen();
2595 
2596  if (failed(parser.parseType(rootType))) {
2597  return failure();
2598  }
2599 
2600  if (!extraBindings.empty()) {
2601  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2602  return failure();
2603  }
2604 
2605  if (extraBindingTypes.size() != extraBindings.size()) {
2606  return parser.emitError(parser.getNameLoc(),
2607  "expected types to be provided for all operands");
2608  }
2609 
2610  // The paren is truly optional.
2611  (void)parser.parseOptionalRParen();
2612  return success();
2613 }
2614 
2616  Value root, Type rootType,
2617  ValueRange extraBindings,
2618  TypeRange extraBindingTypes) {
2619  if (!root)
2620  return;
2621 
2622  printer << root;
2623  bool hasExtras = !extraBindings.empty();
2624  if (hasExtras) {
2625  printer << ", ";
2626  printer.printOperands(extraBindings);
2627  }
2628 
2629  printer << " : ";
2630  if (hasExtras)
2631  printer << "(";
2632 
2633  printer << rootType;
2634  if (hasExtras)
2635  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2636 }
2637 
2638 /// Returns `true` if the given op operand may be consuming the handle value in
2639 /// the Transform IR. That is, if it may have a Free effect on it.
2641  // Conservatively assume the effect being present in absence of the interface.
2642  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2643  if (!iface)
2644  return true;
2645 
2646  return isHandleConsumed(use.get(), iface);
2647 }
2648 
2649 LogicalResult
2651  function_ref<InFlightDiagnostic()> reportError) {
2652  OpOperand *potentialConsumer = nullptr;
2653  for (OpOperand &use : value.getUses()) {
2654  if (!isValueUsePotentialConsumer(use))
2655  continue;
2656 
2657  if (!potentialConsumer) {
2658  potentialConsumer = &use;
2659  continue;
2660  }
2661 
2662  InFlightDiagnostic diag = reportError()
2663  << " has more than one potential consumer";
2664  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2665  << "used here as operand #" << potentialConsumer->getOperandNumber();
2666  diag.attachNote(use.getOwner()->getLoc())
2667  << "used here as operand #" << use.getOperandNumber();
2668  return diag;
2669  }
2670 
2671  return success();
2672 }
2673 
2674 LogicalResult transform::SequenceOp::verify() {
2675  assert(getBodyBlock()->getNumArguments() >= 1 &&
2676  "the number of arguments must have been verified to be more than 1 by "
2677  "PossibleTopLevelTransformOpTrait");
2678 
2679  if (!getRoot() && !getExtraBindings().empty()) {
2680  return emitOpError()
2681  << "does not expect extra operands when used as top-level";
2682  }
2683 
2684  // Check if a block argument has more than one consuming use.
2685  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2686  if (failed(checkDoubleConsume(arg, [this, arg]() {
2687  return (emitOpError() << "block argument #" << arg.getArgNumber());
2688  }))) {
2689  return failure();
2690  }
2691  }
2692 
2693  // Check properties of the nested operations they cannot check themselves.
2694  for (Operation &child : *getBodyBlock()) {
2695  if (!isa<TransformOpInterface>(child) &&
2696  &child != &getBodyBlock()->back()) {
2698  emitOpError()
2699  << "expected children ops to implement TransformOpInterface";
2700  diag.attachNote(child.getLoc()) << "op without interface";
2701  return diag;
2702  }
2703 
2704  for (OpResult result : child.getResults()) {
2705  auto report = [&]() {
2706  return (child.emitError() << "result #" << result.getResultNumber());
2707  };
2708  if (failed(checkDoubleConsume(result, report)))
2709  return failure();
2710  }
2711  }
2712 
2713  if (!getBodyBlock()->mightHaveTerminator())
2714  return emitOpError() << "expects to have a terminator in the body";
2715 
2716  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2717  getOperation()->getResultTypes()) {
2718  InFlightDiagnostic diag = emitOpError()
2719  << "expects the types of the terminator operands "
2720  "to match the types of the result";
2721  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2722  return diag;
2723  }
2724  return success();
2725 }
2726 
2727 void transform::SequenceOp::getEffects(
2728  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2729  getPotentialTopLevelEffects(effects);
2730 }
2731 
2733 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2734  assert(point == getBody() && "unexpected region index");
2735  if (getOperation()->getNumOperands() > 0)
2736  return getOperation()->getOperands();
2737  return OperandRange(getOperation()->operand_end(),
2738  getOperation()->operand_end());
2739 }
2740 
2741 void transform::SequenceOp::getSuccessorRegions(
2742  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2743  if (point.isParent()) {
2744  Region *bodyRegion = &getBody();
2745  regions.emplace_back(bodyRegion, getNumOperands() != 0
2746  ? bodyRegion->getArguments()
2748  return;
2749  }
2750 
2751  assert(point == getBody() && "unexpected region index");
2752  regions.emplace_back(getOperation()->getResults());
2753 }
2754 
2755 void transform::SequenceOp::getRegionInvocationBounds(
2756  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2757  (void)operands;
2758  bounds.emplace_back(1, 1);
2759 }
2760 
2761 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2762  TypeRange resultTypes,
2763  FailurePropagationMode failurePropagationMode,
2764  Value root,
2765  SequenceBodyBuilderFn bodyBuilder) {
2766  build(builder, state, resultTypes, failurePropagationMode, root,
2767  /*extra_bindings=*/ValueRange());
2768  Type bbArgType = root.getType();
2769  buildSequenceBody(builder, state, bbArgType,
2770  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2771 }
2772 
2773 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2774  TypeRange resultTypes,
2775  FailurePropagationMode failurePropagationMode,
2776  Value root, ValueRange extraBindings,
2777  SequenceBodyBuilderArgsFn bodyBuilder) {
2778  build(builder, state, resultTypes, failurePropagationMode, root,
2779  extraBindings);
2780  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2781  bodyBuilder);
2782 }
2783 
2784 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2785  TypeRange resultTypes,
2786  FailurePropagationMode failurePropagationMode,
2787  Type bbArgType,
2788  SequenceBodyBuilderFn bodyBuilder) {
2789  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2790  /*extra_bindings=*/ValueRange());
2791  buildSequenceBody(builder, state, bbArgType,
2792  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2793 }
2794 
2795 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2796  TypeRange resultTypes,
2797  FailurePropagationMode failurePropagationMode,
2798  Type bbArgType, TypeRange extraBindingTypes,
2799  SequenceBodyBuilderArgsFn bodyBuilder) {
2800  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2801  /*extra_bindings=*/ValueRange());
2802  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2803 }
2804 
2805 //===----------------------------------------------------------------------===//
2806 // PrintOp
2807 //===----------------------------------------------------------------------===//
2808 
2809 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2810  StringRef name) {
2811  if (!name.empty())
2812  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2813 }
2814 
2815 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2816  Value target, StringRef name) {
2817  result.addOperands({target});
2818  build(builder, result, name);
2819 }
2820 
2822 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2823  transform::TransformResults &results,
2824  transform::TransformState &state) {
2825  llvm::outs() << "[[[ IR printer: ";
2826  if (getName().has_value())
2827  llvm::outs() << *getName() << " ";
2828 
2829  OpPrintingFlags printFlags;
2830  if (getAssumeVerified().value_or(false))
2831  printFlags.assumeVerified();
2832  if (getUseLocalScope().value_or(false))
2833  printFlags.useLocalScope();
2834  if (getSkipRegions().value_or(false))
2835  printFlags.skipRegions();
2836 
2837  if (!getTarget()) {
2838  llvm::outs() << "top-level ]]]\n";
2839  state.getTopLevel()->print(llvm::outs(), printFlags);
2840  llvm::outs() << "\n";
2841  llvm::outs().flush();
2843  }
2844 
2845  llvm::outs() << "]]]\n";
2846  for (Operation *target : state.getPayloadOps(getTarget())) {
2847  target->print(llvm::outs(), printFlags);
2848  llvm::outs() << "\n";
2849  }
2850 
2851  llvm::outs().flush();
2853 }
2854 
2855 void transform::PrintOp::getEffects(
2856  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2857  // We don't really care about mutability here, but `getTarget` now
2858  // unconditionally casts to a specific type before verification could run
2859  // here.
2860  if (!getTargetMutable().empty())
2861  onlyReadsHandle(getTargetMutable()[0], effects);
2862  onlyReadsPayload(effects);
2863 
2864  // There is no resource for stderr file descriptor, so just declare print
2865  // writes into the default resource.
2866  effects.emplace_back(MemoryEffects::Write::get());
2867 }
2868 
2869 //===----------------------------------------------------------------------===//
2870 // VerifyOp
2871 //===----------------------------------------------------------------------===//
2872 
2874 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2875  Operation *target,
2877  transform::TransformState &state) {
2878  if (failed(::mlir::verify(target))) {
2880  << "failed to verify payload op";
2881  diag.attachNote(target->getLoc()) << "payload op";
2882  return diag;
2883  }
2885 }
2886 
2887 void transform::VerifyOp::getEffects(
2888  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2889  transform::onlyReadsHandle(getTargetMutable(), effects);
2890 }
2891 
2892 //===----------------------------------------------------------------------===//
2893 // YieldOp
2894 //===----------------------------------------------------------------------===//
2895 
2896 void transform::YieldOp::getEffects(
2897  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2898  onlyReadsHandle(getOperandsMutable(), effects);
2899 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block arguments.
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:734
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:297
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:289
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:283
This is a value defined by a result of an operation.
Definition: Value.h:433
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:750
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:231
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:204
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.