MLIR  20.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Verifier.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Pass/PassRegistry.h"
33 #include "mlir/Transforms/CSE.h"
37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include <optional>
45 
46 #define DEBUG_TYPE "transform-dialect"
47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
48 
49 #define DEBUG_TYPE_MATCHER "transform-matcher"
50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
52 
53 using namespace mlir;
54 
55 static ParseResult parseSequenceOpOperands(
56  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
57  Type &rootType,
58  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59  SmallVectorImpl<Type> &extraBindingTypes);
60 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
61  Value root, Type rootType,
62  ValueRange extraBindings,
63  TypeRange extraBindingTypes);
64 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
65  ArrayAttr matchers, ArrayAttr actions);
66 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
67  ArrayAttr &matchers,
68  ArrayAttr &actions);
69 
70 /// Helper function to check if the given transform op is contained in (or
71 /// equal to) the given payload target op. In that case, an error is returned.
72 /// Transforming transform IR that is currently executing is generally unsafe.
74 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
75  Operation *payload) {
76  Operation *transformAncestor = transform.getOperation();
77  while (transformAncestor) {
78  if (transformAncestor == payload) {
80  transform.emitDefiniteFailure()
81  << "cannot apply transform to itself (or one of its ancestors)";
82  diag.attachNote(payload->getLoc()) << "target payload op";
83  return diag;
84  }
85  transformAncestor = transformAncestor->getParentOp();
86  }
88 }
89 
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
92 
93 //===----------------------------------------------------------------------===//
94 // AlternativesOp
95 //===----------------------------------------------------------------------===//
96 
98 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
99  if (!point.isParent() && getOperation()->getNumOperands() == 1)
100  return getOperation()->getOperands();
101  return OperandRange(getOperation()->operand_end(),
102  getOperation()->operand_end());
103 }
104 
105 void transform::AlternativesOp::getSuccessorRegions(
106  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
107  for (Region &alternative : llvm::drop_begin(
108  getAlternatives(),
109  point.isParent() ? 0
110  : point.getRegionOrNull()->getRegionNumber() + 1)) {
111  regions.emplace_back(&alternative, !getOperands().empty()
112  ? alternative.getArguments()
114  }
115  if (!point.isParent())
116  regions.emplace_back(getOperation()->getResults());
117 }
118 
119 void transform::AlternativesOp::getRegionInvocationBounds(
120  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
121  (void)operands;
122  // The region corresponding to the first alternative is always executed, the
123  // remaining may or may not be executed.
124  bounds.reserve(getNumRegions());
125  bounds.emplace_back(1, 1);
126  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
127 }
128 
130  transform::TransformResults &results) {
131  for (const auto &res : block->getParentOp()->getOpResults())
132  results.set(res, {});
133 }
134 
136 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
138  transform::TransformState &state) {
139  SmallVector<Operation *> originals;
140  if (Value scopeHandle = getScope())
141  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
142  else
143  originals.push_back(state.getTopLevel());
144 
145  for (Operation *original : originals) {
146  if (original->isAncestor(getOperation())) {
147  auto diag = emitDefiniteFailure()
148  << "scope must not contain the transforms being applied";
149  diag.attachNote(original->getLoc()) << "scope";
150  return diag;
151  }
152  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
153  auto diag = emitDefiniteFailure()
154  << "only isolated-from-above ops can be alternative scopes";
155  diag.attachNote(original->getLoc()) << "scope";
156  return diag;
157  }
158  }
159 
160  for (Region &reg : getAlternatives()) {
161  // Clone the scope operations and make the transforms in this alternative
162  // region apply to them by virtue of mapping the block argument (the only
163  // visible handle) to the cloned scope operations. This effectively prevents
164  // the transformation from accessing any IR outside the scope.
165  auto scope = state.make_region_scope(reg);
166  auto clones = llvm::to_vector(
167  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
168  auto deleteClones = llvm::make_scope_exit([&] {
169  for (Operation *clone : clones)
170  clone->erase();
171  });
172  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
174 
175  bool failed = false;
176  for (Operation &transform : reg.front().without_terminator()) {
178  state.applyTransform(cast<TransformOpInterface>(transform));
179  if (result.isSilenceableFailure()) {
180  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
181  << "\n");
182  failed = true;
183  break;
184  }
185 
186  if (::mlir::failed(result.silence()))
188  }
189 
190  // If all operations in the given alternative succeeded, no need to consider
191  // the rest. Replace the original scoping operation with the clone on which
192  // the transformations were performed.
193  if (!failed) {
194  // We will be using the clones, so cancel their scheduled deletion.
195  deleteClones.release();
196  TrackingListener listener(state, *this);
197  IRRewriter rewriter(getContext(), &listener);
198  for (const auto &kvp : llvm::zip(originals, clones)) {
199  Operation *original = std::get<0>(kvp);
200  Operation *clone = std::get<1>(kvp);
201  original->getBlock()->getOperations().insert(original->getIterator(),
202  clone);
203  rewriter.replaceOp(original, clone->getResults());
204  }
205  detail::forwardTerminatorOperands(&reg.front(), state, results);
207  }
208  }
209  return emitSilenceableError() << "all alternatives failed";
210 }
211 
212 void transform::AlternativesOp::getEffects(
213  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
214  consumesHandle(getOperation()->getOpOperands(), effects);
215  producesHandle(getOperation()->getOpResults(), effects);
216  for (Region *region : getRegions()) {
217  if (!region->empty())
218  producesHandle(region->front().getArguments(), effects);
219  }
220  modifiesPayload(effects);
221 }
222 
223 LogicalResult transform::AlternativesOp::verify() {
224  for (Region &alternative : getAlternatives()) {
225  Block &block = alternative.front();
226  Operation *terminator = block.getTerminator();
227  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
228  InFlightDiagnostic diag = emitOpError()
229  << "expects terminator operands to have the "
230  "same type as results of the operation";
231  diag.attachNote(terminator->getLoc()) << "terminator";
232  return diag;
233  }
234  }
235 
236  return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // AnnotateOp
241 //===----------------------------------------------------------------------===//
242 
244 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
246  transform::TransformState &state) {
247  SmallVector<Operation *> targets =
248  llvm::to_vector(state.getPayloadOps(getTarget()));
249 
251  if (auto paramH = getParam()) {
252  ArrayRef<Attribute> params = state.getParams(paramH);
253  if (params.size() != 1) {
254  if (targets.size() != params.size()) {
255  return emitSilenceableError()
256  << "parameter and target have different payload lengths ("
257  << params.size() << " vs " << targets.size() << ")";
258  }
259  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
260  target->setAttr(getName(), attr);
262  }
263  attr = params[0];
264  }
265  for (auto *target : targets)
266  target->setAttr(getName(), attr);
268 }
269 
270 void transform::AnnotateOp::getEffects(
271  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
272  onlyReadsHandle(getTargetMutable(), effects);
273  onlyReadsHandle(getParamMutable(), effects);
274  modifiesPayload(effects);
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // ApplyCommonSubexpressionEliminationOp
279 //===----------------------------------------------------------------------===//
280 
282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
283  transform::TransformRewriter &rewriter, Operation *target,
284  ApplyToEachResultList &results, transform::TransformState &state) {
285  // Make sure that this transform is not applied to itself. Modifying the
286  // transform IR while it is being interpreted is generally dangerous.
287  DiagnosedSilenceableFailure payloadCheck =
289  if (!payloadCheck.succeeded())
290  return payloadCheck;
291 
292  DominanceInfo domInfo;
293  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
295 }
296 
297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
298  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
299  transform::onlyReadsHandle(getTargetMutable(), effects);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // ApplyDeadCodeEliminationOp
305 //===----------------------------------------------------------------------===//
306 
307 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
308  transform::TransformRewriter &rewriter, Operation *target,
309  ApplyToEachResultList &results, transform::TransformState &state) {
310  // Make sure that this transform is not applied to itself. Modifying the
311  // transform IR while it is being interpreted is generally dangerous.
312  DiagnosedSilenceableFailure payloadCheck =
314  if (!payloadCheck.succeeded())
315  return payloadCheck;
316 
317  // Maintain a worklist of potentially dead ops.
318  SetVector<Operation *> worklist;
319 
320  // Helper function that adds all defining ops of used values (operands and
321  // operands of nested ops).
322  auto addDefiningOpsToWorklist = [&](Operation *op) {
323  op->walk([&](Operation *op) {
324  for (Value v : op->getOperands())
325  if (Operation *defOp = v.getDefiningOp())
326  if (target->isProperAncestor(defOp))
327  worklist.insert(defOp);
328  });
329  };
330 
331  // Helper function that erases an op.
332  auto eraseOp = [&](Operation *op) {
333  // Remove op and nested ops from the worklist.
334  op->walk([&](Operation *op) {
335  const auto *it = llvm::find(worklist, op);
336  if (it != worklist.end())
337  worklist.erase(it);
338  });
339  rewriter.eraseOp(op);
340  };
341 
342  // Initial walk over the IR.
343  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
344  if (op != target && isOpTriviallyDead(op)) {
345  addDefiningOpsToWorklist(op);
346  eraseOp(op);
347  }
348  });
349 
350  // Erase all ops that have become dead.
351  while (!worklist.empty()) {
352  Operation *op = worklist.pop_back_val();
353  if (!isOpTriviallyDead(op))
354  continue;
355  addDefiningOpsToWorklist(op);
356  eraseOp(op);
357  }
358 
360 }
361 
362 void transform::ApplyDeadCodeEliminationOp::getEffects(
363  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
364  transform::onlyReadsHandle(getTargetMutable(), effects);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ApplyPatternsOp
370 //===----------------------------------------------------------------------===//
371 
372 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
373  transform::TransformRewriter &rewriter, Operation *target,
374  ApplyToEachResultList &results, transform::TransformState &state) {
375  // Make sure that this transform is not applied to itself. Modifying the
376  // transform IR while it is being interpreted is generally dangerous. Even
377  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
378  // performs many additional simplifications such as dead code elimination.
379  DiagnosedSilenceableFailure payloadCheck =
381  if (!payloadCheck.succeeded())
382  return payloadCheck;
383 
384  // Gather all specified patterns.
385  MLIRContext *ctx = target->getContext();
387  if (!getRegion().empty()) {
388  for (Operation &op : getRegion().front()) {
389  cast<transform::PatternDescriptorOpInterface>(&op)
390  .populatePatternsWithState(patterns, state);
391  }
392  }
393 
394  // Configure the GreedyPatternRewriteDriver.
396  config.listener =
397  static_cast<RewriterBase::Listener *>(rewriter.getListener());
398  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
399 
400  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
402  : getMaxIterations();
403  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
405  : getMaxNumRewrites();
406 
407  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
408  // was requested, apply the greedy pattern rewrite only once. (The greedy
409  // pattern rewrite driver already iterates to a fixpoint internally.)
410  bool cseChanged = false;
411  // One or two iterations should be sufficient. Stop iterating after a certain
412  // threshold to make debugging easier.
413  static const int64_t kNumMaxIterations = 50;
414  int64_t iteration = 0;
415  do {
416  LogicalResult result = failure();
417  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
418  // Op is isolated from above. Apply patterns and also perform region
419  // simplification.
420  result = applyPatternsGreedily(target, frozenPatterns, config);
421  } else {
422  // Manually gather list of ops because the other
423  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
424  // from above. This way, patterns can be applied to ops that are not
425  // isolated from above. Regions are not being simplified. Furthermore,
426  // only a single greedy rewrite iteration is performed.
428  target->walk([&](Operation *nestedOp) {
429  if (target != nestedOp)
430  ops.push_back(nestedOp);
431  });
432  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
433  }
434 
435  // A failure typically indicates that the pattern application did not
436  // converge.
437  if (failed(result)) {
438  return emitSilenceableFailure(target)
439  << "greedy pattern application failed";
440  }
441 
442  if (getApplyCse()) {
443  DominanceInfo domInfo;
444  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
445  &cseChanged);
446  }
447  } while (cseChanged && ++iteration < kNumMaxIterations);
448 
449  if (iteration == kNumMaxIterations)
450  return emitDefiniteFailure() << "fixpoint iteration did not converge";
451 
453 }
454 
455 LogicalResult transform::ApplyPatternsOp::verify() {
456  if (!getRegion().empty()) {
457  for (Operation &op : getRegion().front()) {
458  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
459  InFlightDiagnostic diag = emitOpError()
460  << "expected children ops to implement "
461  "PatternDescriptorOpInterface";
462  diag.attachNote(op.getLoc()) << "op without interface";
463  return diag;
464  }
465  }
466  }
467  return success();
468 }
469 
470 void transform::ApplyPatternsOp::getEffects(
471  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
472  transform::onlyReadsHandle(getTargetMutable(), effects);
474 }
475 
476 void transform::ApplyPatternsOp::build(
477  OpBuilder &builder, OperationState &result, Value target,
478  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
479  result.addOperands(target);
480 
481  OpBuilder::InsertionGuard g(builder);
482  Region *region = result.addRegion();
483  builder.createBlock(region);
484  if (bodyBuilder)
485  bodyBuilder(builder, result.location);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // ApplyCanonicalizationPatternsOp
490 //===----------------------------------------------------------------------===//
491 
492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
494  MLIRContext *ctx = patterns.getContext();
495  for (Dialect *dialect : ctx->getLoadedDialects())
496  dialect->getCanonicalizationPatterns(patterns);
498  op.getCanonicalizationPatterns(patterns, ctx);
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // ApplyConversionPatternsOp
503 //===----------------------------------------------------------------------===//
504 
505 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
508  MLIRContext *ctx = getContext();
509 
510  // Instantiate the default type converter if a type converter builder is
511  // specified.
512  std::unique_ptr<TypeConverter> defaultTypeConverter;
513  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
514  getDefaultTypeConverter();
515  if (typeConverterBuilder)
516  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
517 
518  // Configure conversion target.
519  ConversionTarget conversionTarget(*getContext());
520  if (getLegalOps())
521  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522  conversionTarget.addLegalOp(
523  OperationName(cast<StringAttr>(attr).getValue(), ctx));
524  if (getIllegalOps())
525  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526  conversionTarget.addIllegalOp(
527  OperationName(cast<StringAttr>(attr).getValue(), ctx));
528  if (getLegalDialects())
529  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531  if (getIllegalDialects())
532  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
534 
535  // Gather all specified patterns.
537  // Need to keep the converters alive until after pattern application because
538  // the patterns take a reference to an object that would otherwise get out of
539  // scope.
540  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
541  if (!getPatterns().empty()) {
542  for (Operation &op : getPatterns().front()) {
543  auto descriptor =
544  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
545 
546  // Check if this pattern set specifies a type converter.
547  std::unique_ptr<TypeConverter> typeConverter =
548  descriptor.getTypeConverter();
549  TypeConverter *converter = nullptr;
550  if (typeConverter) {
551  keepAliveConverters.emplace_back(std::move(typeConverter));
552  converter = keepAliveConverters.back().get();
553  } else {
554  // No type converter specified: Use the default type converter.
555  if (!defaultTypeConverter) {
556  auto diag = emitDefiniteFailure()
557  << "pattern descriptor does not specify type "
558  "converter and apply_conversion_patterns op has "
559  "no default type converter";
560  diag.attachNote(op.getLoc()) << "pattern descriptor op";
561  return diag;
562  }
563  converter = defaultTypeConverter.get();
564  }
565 
566  // Add descriptor-specific updates to the conversion target, which may
567  // depend on the final type converter. In structural converters, the
568  // legality of types dictates the dynamic legality of an operation.
569  descriptor.populateConversionTargetRules(*converter, conversionTarget);
570 
571  descriptor.populatePatterns(*converter, patterns);
572  }
573  }
574 
575  // Attach a tracking listener if handles should be preserved. We configure the
576  // listener to allow op replacements with different names, as conversion
577  // patterns typically replace ops with replacement ops that have a different
578  // name.
579  TrackingListenerConfig trackingConfig;
580  trackingConfig.requireMatchingReplacementOpName = false;
581  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
582  ConversionConfig conversionConfig;
583  if (getPreserveHandles())
584  conversionConfig.listener = &trackingListener;
585 
586  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
587  for (Operation *target : state.getPayloadOps(getTarget())) {
588  // Make sure that this transform is not applied to itself. Modifying the
589  // transform IR while it is being interpreted is generally dangerous.
590  DiagnosedSilenceableFailure payloadCheck =
592  if (!payloadCheck.succeeded())
593  return payloadCheck;
594 
595  LogicalResult status = failure();
596  if (getPartialConversion()) {
597  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
598  conversionConfig);
599  } else {
600  status = applyFullConversion(target, conversionTarget, frozenPatterns,
601  conversionConfig);
602  }
603 
604  // Check dialect conversion state.
606  if (failed(status)) {
607  diag = emitSilenceableError() << "dialect conversion failed";
608  diag.attachNote(target->getLoc()) << "target op";
609  }
610 
611  // Check tracking listener error state.
612  DiagnosedSilenceableFailure trackingFailure =
613  trackingListener.checkAndResetError();
614  if (!trackingFailure.succeeded()) {
615  if (diag.succeeded()) {
616  // Tracking failure is the only failure.
617  return trackingFailure;
618  } else {
619  diag.attachNote() << "tracking listener also failed: "
620  << trackingFailure.getMessage();
621  (void)trackingFailure.silence();
622  }
623  }
624 
625  if (!diag.succeeded())
626  return diag;
627  }
628 
630 }
631 
633  if (getNumRegions() != 1 && getNumRegions() != 2)
634  return emitOpError() << "expected 1 or 2 regions";
635  if (!getPatterns().empty()) {
636  for (Operation &op : getPatterns().front()) {
637  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639  emitOpError() << "expected pattern children ops to implement "
640  "ConversionPatternDescriptorOpInterface";
641  diag.attachNote(op.getLoc()) << "op without interface";
642  return diag;
643  }
644  }
645  }
646  if (getNumRegions() == 2) {
647  Region &typeConverterRegion = getRegion(1);
648  if (!llvm::hasSingleElement(typeConverterRegion.front()))
649  return emitOpError()
650  << "expected exactly one op in default type converter region";
651  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
652  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
653  maybeTypeConverter);
654  if (!typeConverterOp) {
655  InFlightDiagnostic diag = emitOpError()
656  << "expected default converter child op to "
657  "implement TypeConverterBuilderOpInterface";
658  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
659  return diag;
660  }
661  // Check default type converter type.
662  if (!getPatterns().empty()) {
663  for (Operation &op : getPatterns().front()) {
664  auto descriptor =
665  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
667  return failure();
668  }
669  }
670  }
671  return success();
672 }
673 
674 void transform::ApplyConversionPatternsOp::getEffects(
675  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676  if (!getPreserveHandles()) {
677  transform::consumesHandle(getTargetMutable(), effects);
678  } else {
679  transform::onlyReadsHandle(getTargetMutable(), effects);
680  }
682 }
683 
684 void transform::ApplyConversionPatternsOp::build(
685  OpBuilder &builder, OperationState &result, Value target,
686  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
687  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
688  result.addOperands(target);
689 
690  {
691  OpBuilder::InsertionGuard g(builder);
692  Region *region1 = result.addRegion();
693  builder.createBlock(region1);
694  if (patternsBodyBuilder)
695  patternsBodyBuilder(builder, result.location);
696  }
697  {
698  OpBuilder::InsertionGuard g(builder);
699  Region *region2 = result.addRegion();
700  builder.createBlock(region2);
701  if (typeConverterBodyBuilder)
702  typeConverterBodyBuilder(builder, result.location);
703  }
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // ApplyToLLVMConversionPatternsOp
708 //===----------------------------------------------------------------------===//
709 
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
711  TypeConverter &typeConverter, RewritePatternSet &patterns) {
712  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
713  assert(dialect && "expected that dialect is loaded");
714  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
715  // ConversionTarget is currently ignored because the enclosing
716  // apply_conversion_patterns op sets up its own ConversionTarget.
717  ConversionTarget target(*getContext());
718  iface->populateConvertToLLVMConversionPatterns(
719  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
720 }
721 
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723  transform::TypeConverterBuilderOpInterface builder) {
724  if (builder.getTypeConverterType() != "LLVMTypeConverter")
725  return emitOpError("expected LLVMTypeConverter");
726  return success();
727 }
728 
730  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
731  if (!dialect)
732  return emitOpError("unknown dialect or dialect not loaded: ")
733  << getDialectName();
734  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
735  if (!iface)
736  return emitOpError(
737  "dialect does not implement ConvertToLLVMPatternInterface or "
738  "extension was not loaded: ")
739  << getDialectName();
740  return success();
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // ApplyLoopInvariantCodeMotionOp
745 //===----------------------------------------------------------------------===//
746 
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
749  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
751  transform::TransformState &state) {
752  // Currently, LICM does not remove operations, so we don't need tracking.
753  // If this ever changes, add a LICM entry point that takes a rewriter.
754  moveLoopInvariantCode(target);
756 }
757 
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
760  transform::onlyReadsHandle(getTargetMutable(), effects);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // ApplyRegisteredPassOp
766 //===----------------------------------------------------------------------===//
767 
768 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
769  transform::TransformRewriter &rewriter, Operation *target,
770  ApplyToEachResultList &results, transform::TransformState &state) {
771  // Make sure that this transform is not applied to itself. Modifying the
772  // transform IR while it is being interpreted is generally dangerous. Even
773  // more so when applying passes because they may perform a wide range of IR
774  // modifications.
775  DiagnosedSilenceableFailure payloadCheck =
777  if (!payloadCheck.succeeded())
778  return payloadCheck;
779 
780  // Get pass or pass pipeline from registry.
781  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
782  if (!info)
783  info = PassInfo::lookup(getPassName());
784  if (!info)
785  return emitDefiniteFailure()
786  << "unknown pass or pass pipeline: " << getPassName();
787 
788  // Create pass manager and run the pass or pass pipeline.
789  PassManager pm(getContext());
790  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
791  emitError(msg);
792  return failure();
793  }))) {
794  return emitDefiniteFailure()
795  << "failed to add pass or pass pipeline to pipeline: "
796  << getPassName();
797  }
798  if (failed(pm.run(target))) {
799  auto diag = emitSilenceableError() << "pass pipeline failed";
800  diag.attachNote(target->getLoc()) << "target op";
801  return diag;
802  }
803 
804  results.push_back(target);
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // CastOp
810 //===----------------------------------------------------------------------===//
811 
813 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
814  Operation *target, ApplyToEachResultList &results,
815  transform::TransformState &state) {
816  results.push_back(target);
818 }
819 
820 void transform::CastOp::getEffects(
821  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
822  onlyReadsPayload(effects);
823  onlyReadsHandle(getInputMutable(), effects);
824  producesHandle(getOperation()->getOpResults(), effects);
825 }
826 
827 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
828  assert(inputs.size() == 1 && "expected one input");
829  assert(outputs.size() == 1 && "expected one output");
830  return llvm::all_of(
831  std::initializer_list<Type>{inputs.front(), outputs.front()},
832  llvm::IsaPred<transform::TransformHandleTypeInterface>);
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // CollectMatchingOp
837 //===----------------------------------------------------------------------===//
838 
839 /// Applies matcher operations from the given `block` using
840 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
841 /// accordingly. If any of the matcher produces a silenceable failure, discards
842 /// it (printing the content to the debug output stream) and returns failure. If
843 /// any of the matchers produces a definite failure, reports it and returns
844 /// failure. If all matchers in the block succeed, populates `mappings` with the
845 /// payload entities associated with the block terminator operands. Note that
846 /// `mappings` will be cleared before that.
849  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
851  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
852  assert(block.getParent() && "cannot match using a detached block");
853  auto matchScope = state.make_region_scope(*block.getParent());
854  if (failed(
855  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
857 
858  for (Operation &match : block.without_terminator()) {
859  if (!isa<transform::MatchOpInterface>(match)) {
860  return emitDefiniteFailure(match.getLoc())
861  << "expected operations in the match part to "
862  "implement MatchOpInterface";
863  }
865  state.applyTransform(cast<transform::TransformOpInterface>(match));
866  if (diag.succeeded())
867  continue;
868 
869  return diag;
870  }
871 
872  // Remember the values mapped to the terminator operands so we can
873  // forward them to the action.
874  ValueRange yieldedValues = block.getTerminator()->getOperands();
875  // Our contract with the caller is that the mappings will contain only the
876  // newly mapped values, clear the rest.
877  mappings.clear();
878  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
880 }
881 
882 /// Returns `true` if both types implement one of the interfaces provided as
883 /// template parameters.
884 template <typename... Tys>
885 static bool implementSameInterface(Type t1, Type t2) {
886  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
887 }
888 
889 /// Returns `true` if both types implement one of the transform dialect
890 /// interfaces.
892  return implementSameInterface<transform::TransformHandleTypeInterface,
893  transform::TransformParamTypeInterface,
894  transform::TransformValueHandleTypeInterface>(
895  t1, t2);
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // CollectMatchingOp
900 //===----------------------------------------------------------------------===//
901 
903 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
905  transform::TransformState &state) {
906  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907  getOperation(), getMatcher());
908  if (matcher.isExternal()) {
909  return emitDefiniteFailure()
910  << "unresolved external symbol " << getMatcher();
911  }
912 
914  rawResults.resize(getOperation()->getNumResults());
915  std::optional<DiagnosedSilenceableFailure> maybeFailure;
916  for (Operation *root : state.getPayloadOps(getRoot())) {
917  WalkResult walkResult = root->walk([&](Operation *op) {
918  DEBUG_MATCHER({
919  DBGS_MATCHER() << "matching ";
920  op->print(llvm::dbgs(),
921  OpPrintingFlags().assumeVerified().skipRegions());
922  llvm::dbgs() << " @" << op << "\n";
923  });
924 
925  // Try matching.
927  SmallVector<transform::MappedValue> inputMapping({op});
929  matcher.getFunctionBody().front(),
930  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
931  mappings);
932  if (diag.isDefiniteFailure())
933  return WalkResult::interrupt();
934  if (diag.isSilenceableFailure()) {
935  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
936  << " failed: " << diag.getMessage());
937  return WalkResult::advance();
938  }
939 
940  // If succeeded, collect results.
941  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
942  if (mapping.size() != 1) {
943  maybeFailure.emplace(emitSilenceableError()
944  << "result #" << i << ", associated with "
945  << mapping.size()
946  << " payload objects, expected 1");
947  return WalkResult::interrupt();
948  }
949  rawResults[i].push_back(mapping[0]);
950  }
951  return WalkResult::advance();
952  });
953  if (walkResult.wasInterrupted())
954  return std::move(*maybeFailure);
955  assert(!maybeFailure && "failure set but the walk was not interrupted");
956 
957  for (auto &&[opResult, rawResult] :
958  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
959  results.setMappedValues(opResult, rawResult);
960  }
961  }
963 }
964 
965 void transform::CollectMatchingOp::getEffects(
966  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
967  onlyReadsHandle(getRootMutable(), effects);
968  producesHandle(getOperation()->getOpResults(), effects);
969  onlyReadsPayload(effects);
970 }
971 
972 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
973  SymbolTableCollection &symbolTable) {
974  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
975  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
976  if (!matcherSymbol ||
977  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978  return emitError() << "unresolved matcher symbol " << getMatcher();
979 
980  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
981  if (argumentTypes.size() != 1 ||
982  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
983  return emitError()
984  << "expected the matcher to take one operation handle argument";
985  }
986  if (!matcherSymbol.getArgAttr(
987  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988  return emitError() << "expected the matcher argument to be marked readonly";
989  }
990 
991  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
992  if (resultTypes.size() != getOperation()->getNumResults()) {
993  return emitError()
994  << "expected the matcher to yield as many values as op has results ("
995  << getOperation()->getNumResults() << "), got "
996  << resultTypes.size();
997  }
998 
999  for (auto &&[i, matcherType, resultType] :
1000  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1001  if (implementSameTransformInterface(matcherType, resultType))
1002  continue;
1003 
1004  return emitError()
1005  << "mismatching type interfaces for matcher result and op result #"
1006  << i;
1007  }
1008 
1009  return success();
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // ForeachMatchOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 // This is fine because nothing is actually consumed by this op.
1017 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1018 
1020 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1021  transform::TransformResults &results,
1022  transform::TransformState &state) {
1024  matchActionPairs;
1025  matchActionPairs.reserve(getMatchers().size());
1026  SymbolTableCollection symbolTable;
1027  for (auto &&[matcher, action] :
1028  llvm::zip_equal(getMatchers(), getActions())) {
1029  auto matcherSymbol =
1030  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1031  getOperation(), cast<SymbolRefAttr>(matcher));
1032  auto actionSymbol =
1033  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1034  getOperation(), cast<SymbolRefAttr>(action));
1035  assert(matcherSymbol && actionSymbol &&
1036  "unresolved symbols not caught by the verifier");
1037 
1038  if (matcherSymbol.isExternal())
1039  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1040  if (actionSymbol.isExternal())
1041  return emitDefiniteFailure() << "unresolved external symbol " << action;
1042 
1043  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1044  }
1045 
1046  DiagnosedSilenceableFailure overallDiag =
1048 
1049  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1050  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1051  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1052  // Explicitly add the mapping for the first block argument (the op being
1053  // matched).
1054  matchInputMapping.emplace_back();
1055  transform::detail::prepareValueMappings(matchInputMapping,
1056  getForwardedInputs(), state);
1057  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1058  actionResultMapping.resize(getForwardedOutputs().size());
1059 
1060  for (Operation *root : state.getPayloadOps(getRoot())) {
1061  WalkResult walkResult = root->walk([&](Operation *op) {
1062  // If getRestrictRoot is not present, skip over the root op itself so we
1063  // don't invalidate it.
1064  if (!getRestrictRoot() && op == root)
1065  return WalkResult::advance();
1066 
1067  DEBUG_MATCHER({
1068  DBGS_MATCHER() << "matching ";
1069  op->print(llvm::dbgs(),
1070  OpPrintingFlags().assumeVerified().skipRegions());
1071  llvm::dbgs() << " @" << op << "\n";
1072  });
1073 
1074  firstMatchArgument.clear();
1075  firstMatchArgument.push_back(op);
1076 
1077  // Try all the match/action pairs until the first successful match.
1078  for (auto [matcher, action] : matchActionPairs) {
1080  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1081  state, matchOutputMapping);
1082  if (diag.isDefiniteFailure())
1083  return WalkResult::interrupt();
1084  if (diag.isSilenceableFailure()) {
1085  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1086  << " failed: " << diag.getMessage());
1087  continue;
1088  }
1089 
1090  auto scope = state.make_region_scope(action.getFunctionBody());
1091  if (failed(state.mapBlockArguments(
1092  action.getFunctionBody().front().getArguments(),
1093  matchOutputMapping))) {
1094  return WalkResult::interrupt();
1095  }
1096 
1097  for (Operation &transform :
1098  action.getFunctionBody().front().without_terminator()) {
1100  state.applyTransform(cast<TransformOpInterface>(transform));
1101  if (result.isDefiniteFailure())
1102  return WalkResult::interrupt();
1103  if (result.isSilenceableFailure()) {
1104  if (overallDiag.succeeded()) {
1105  overallDiag = emitSilenceableError() << "actions failed";
1106  }
1107  overallDiag.attachNote(action->getLoc())
1108  << "failed action: " << result.getMessage();
1109  overallDiag.attachNote(op->getLoc())
1110  << "when applied to this matching payload";
1111  (void)result.silence();
1112  continue;
1113  }
1114  }
1115  if (failed(detail::appendValueMappings(
1116  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1117  action.getFunctionBody().front().getTerminator()->getOperands(),
1118  state, getFlattenResults()))) {
1120  << "action @" << action.getName()
1121  << " has results associated with multiple payload entities, "
1122  "but flattening was not requested";
1123  return WalkResult::interrupt();
1124  }
1125  break;
1126  }
1127  return WalkResult::advance();
1128  });
1129  if (walkResult.wasInterrupted())
1131  }
1132 
1133  // The root operation should not have been affected, so we can just reassign
1134  // the payload to the result. Note that we need to consume the root handle to
1135  // make sure any handles to operations inside, that could have been affected
1136  // by actions, are invalidated.
1137  results.set(llvm::cast<OpResult>(getUpdated()),
1138  state.getPayloadOps(getRoot()));
1139  for (auto &&[result, mapping] :
1140  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1141  results.setMappedValues(result, mapping);
1142  }
1143  return overallDiag;
1144 }
1145 
1146 void transform::ForeachMatchOp::getAsmResultNames(
1147  OpAsmSetValueNameFn setNameFn) {
1148  setNameFn(getUpdated(), "updated_root");
1149  for (Value v : getForwardedOutputs()) {
1150  setNameFn(v, "yielded");
1151  }
1152 }
1153 
1154 void transform::ForeachMatchOp::getEffects(
1155  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1156  // Bail if invalid.
1157  if (getOperation()->getNumOperands() < 1 ||
1158  getOperation()->getNumResults() < 1) {
1159  return modifiesPayload(effects);
1160  }
1161 
1162  consumesHandle(getRootMutable(), effects);
1163  onlyReadsHandle(getForwardedInputsMutable(), effects);
1164  producesHandle(getOperation()->getOpResults(), effects);
1165  modifiesPayload(effects);
1166 }
1167 
1168 /// Parses the comma-separated list of symbol reference pairs of the format
1169 /// `@matcher -> @action`.
1170 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1171  ArrayAttr &matchers,
1172  ArrayAttr &actions) {
1173  StringAttr matcher;
1174  StringAttr action;
1175  SmallVector<Attribute> matcherList;
1176  SmallVector<Attribute> actionList;
1177  do {
1178  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1179  parser.parseSymbolName(action)) {
1180  return failure();
1181  }
1182  matcherList.push_back(SymbolRefAttr::get(matcher));
1183  actionList.push_back(SymbolRefAttr::get(action));
1184  } while (parser.parseOptionalComma().succeeded());
1185 
1186  matchers = parser.getBuilder().getArrayAttr(matcherList);
1187  actions = parser.getBuilder().getArrayAttr(actionList);
1188  return success();
1189 }
1190 
1191 /// Prints the comma-separated list of symbol reference pairs of the format
1192 /// `@matcher -> @action`.
1194  ArrayAttr matchers, ArrayAttr actions) {
1195  printer.increaseIndent();
1196  printer.increaseIndent();
1197  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1198  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1199  printer.printNewline();
1200  printer << cast<SymbolRefAttr>(matcher) << " -> "
1201  << cast<SymbolRefAttr>(action);
1202  if (idx != matchers.size() - 1)
1203  printer << ", ";
1204  }
1205  printer.decreaseIndent();
1206  printer.decreaseIndent();
1207 }
1208 
1209 LogicalResult transform::ForeachMatchOp::verify() {
1210  if (getMatchers().size() != getActions().size())
1211  return emitOpError() << "expected the same number of matchers and actions";
1212  if (getMatchers().empty())
1213  return emitOpError() << "expected at least one match/action pair";
1214 
1215  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1216  for (Attribute name : getMatchers()) {
1217  if (matcherNames.insert(name).second)
1218  continue;
1219  emitWarning() << "matcher " << name
1220  << " is used more than once, only the first match will apply";
1221  }
1222 
1223  return success();
1224 }
1225 
1226 /// Checks that the attributes of the function-like operation have correct
1227 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1228 /// annotations being present even if they can be inferred from the body.
1230 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1231  bool alsoVerifyInternal = false) {
1232  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1233  llvm::SmallDenseSet<unsigned> consumedArguments;
1234  if (!op.isExternal()) {
1235  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1236  consumedArguments);
1237  }
1238  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1239  bool isConsumed =
1240  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1241  nullptr;
1242  bool isReadOnly =
1243  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1244  nullptr;
1245  if (isConsumed && isReadOnly) {
1246  return transformOp.emitSilenceableError()
1247  << "argument #" << i << " cannot be both readonly and consumed";
1248  }
1249  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1250  return transformOp.emitSilenceableError()
1251  << "must provide consumed/readonly status for arguments of "
1252  "external or called ops";
1253  }
1254  if (op.isExternal())
1255  continue;
1256 
1257  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1258  return transformOp.emitSilenceableError()
1259  << "argument #" << i
1260  << " is consumed in the body but is not marked as such";
1261  }
1262  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1263  // Cannot use op.emitWarning() here as it would attempt to verify the op
1264  // before printing, resulting in infinite recursion.
1265  emitWarning(op->getLoc())
1266  << "op argument #" << i
1267  << " is not consumed in the body but is marked as consumed";
1268  }
1269  }
1271 }
1272 
1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1274  SymbolTableCollection &symbolTable) {
1275  assert(getMatchers().size() == getActions().size());
1276  auto consumedAttr =
1277  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1278  for (auto &&[matcher, action] :
1279  llvm::zip_equal(getMatchers(), getActions())) {
1280  // Presence and typing.
1281  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1282  symbolTable.lookupNearestSymbolFrom(getOperation(),
1283  cast<SymbolRefAttr>(matcher)));
1284  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1285  symbolTable.lookupNearestSymbolFrom(getOperation(),
1286  cast<SymbolRefAttr>(action)));
1287  if (!matcherSymbol ||
1288  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1289  return emitError() << "unresolved matcher symbol " << matcher;
1290  if (!actionSymbol ||
1291  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1292  return emitError() << "unresolved action symbol " << action;
1293 
1294  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1295  /*emitWarnings=*/false,
1296  /*alsoVerifyInternal=*/true)
1297  .checkAndReport())) {
1298  return failure();
1299  }
1300  if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1301  /*emitWarnings=*/false,
1302  /*alsoVerifyInternal=*/true)
1303  .checkAndReport())) {
1304  return failure();
1305  }
1306 
1307  // Input -> matcher forwarding.
1308  TypeRange operandTypes = getOperandTypes();
1309  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310  if (operandTypes.size() != matcherArguments.size()) {
1312  emitError() << "the number of operands (" << operandTypes.size()
1313  << ") doesn't match the number of matcher arguments ("
1314  << matcherArguments.size() << ") for " << matcher;
1315  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1316  return diag;
1317  }
1318  for (auto &&[i, operand, argument] :
1319  llvm::enumerate(operandTypes, matcherArguments)) {
1320  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1322  emitOpError()
1323  << "does not expect matcher symbol to consume its operand #" << i;
1324  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1325  return diag;
1326  }
1327 
1328  if (implementSameTransformInterface(operand, argument))
1329  continue;
1330 
1332  emitError()
1333  << "mismatching type interfaces for operand and matcher argument #"
1334  << i << " of matcher " << matcher;
1335  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1336  return diag;
1337  }
1338 
1339  // Matcher -> action forwarding.
1340  TypeRange matcherResults = matcherSymbol.getResultTypes();
1341  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1342  if (matcherResults.size() != actionArguments.size()) {
1343  return emitError() << "mismatching number of matcher results and "
1344  "action arguments between "
1345  << matcher << " (" << matcherResults.size() << ") and "
1346  << action << " (" << actionArguments.size() << ")";
1347  }
1348  for (auto &&[i, matcherType, actionType] :
1349  llvm::enumerate(matcherResults, actionArguments)) {
1350  if (implementSameTransformInterface(matcherType, actionType))
1351  continue;
1352 
1353  return emitError() << "mismatching type interfaces for matcher result "
1354  "and action argument #"
1355  << i << "of matcher " << matcher << " and action "
1356  << action;
1357  }
1358 
1359  // Action -> result forwarding.
1360  TypeRange actionResults = actionSymbol.getResultTypes();
1361  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1362  if (actionResults.size() != resultTypes.size()) {
1364  emitError() << "the number of action results ("
1365  << actionResults.size() << ") for " << action
1366  << " doesn't match the number of extra op results ("
1367  << resultTypes.size() << ")";
1368  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1369  return diag;
1370  }
1371  for (auto &&[i, resultType, actionType] :
1372  llvm::enumerate(resultTypes, actionResults)) {
1373  if (implementSameTransformInterface(resultType, actionType))
1374  continue;
1375 
1377  emitError() << "mismatching type interfaces for action result #" << i
1378  << " of action " << action << " and op result";
1379  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1380  return diag;
1381  }
1382  }
1383  return success();
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // ForeachOp
1388 //===----------------------------------------------------------------------===//
1389 
1391 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1392  transform::TransformResults &results,
1393  transform::TransformState &state) {
1394  // We store the payloads before executing the body as ops may be removed from
1395  // the mapping by the TrackingRewriter while iteration is in progress.
1397  detail::prepareValueMappings(payloads, getTargets(), state);
1398  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399  bool withZipShortest = getWithZipShortest();
1400 
1401  // In case of `zip_shortest`, set the number of iterations to the
1402  // smallest payload in the targets.
1403  if (withZipShortest) {
1404  numIterations =
1405  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1406  const SmallVector<MappedValue> &B) {
1407  return A.size() < B.size();
1408  })->size();
1409 
1410  for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411  payloads[argIdx].resize(numIterations);
1412  }
1413 
1414  // As we will be "zipping" over them, check all payloads have the same size.
1415  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1416  // when true.
1417  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1418  argIdx++) {
1419  if (payloads[argIdx].size() != numIterations) {
1420  return emitSilenceableError()
1421  << "prior targets' payload size (" << numIterations
1422  << ") differs from payload size (" << payloads[argIdx].size()
1423  << ") of target " << getTargets()[argIdx];
1424  }
1425  }
1426 
1427  // Start iterating, indexing into payloads to obtain the right arguments to
1428  // call the body with - each slice of payloads at the same argument index
1429  // corresponding to a tuple to use as the body's block arguments.
1430  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1431  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1432  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433  auto scope = state.make_region_scope(getBody());
1434  // Set up arguments to the region's block.
1435  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1436  MappedValue argument = payloads[argIdx][iterIdx];
1437  // Note that each blockArg's handle gets associated with just a single
1438  // element from the corresponding target's payload.
1439  if (failed(state.mapBlockArgument(blockArg, {argument})))
1441  }
1442 
1443  // Execute loop body.
1444  for (Operation &transform : getBody().front().without_terminator()) {
1445  DiagnosedSilenceableFailure result = state.applyTransform(
1446  llvm::cast<transform::TransformOpInterface>(transform));
1447  if (!result.succeeded())
1448  return result;
1449  }
1450 
1451  // Append yielded payloads to corresponding results from prior iterations.
1452  OperandRange yieldOperands = getYieldOp().getOperands();
1453  for (auto &&[result, yieldOperand, resTuple] :
1454  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1455  // NB: each iteration we add any number of ops/vals/params to a result.
1456  if (isa<TransformHandleTypeInterface>(result.getType()))
1457  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460  else if (isa<TransformParamTypeInterface>(result.getType()))
1461  llvm::append_range(resTuple, state.getParams(yieldOperand));
1462  else
1463  assert(false && "unhandled handle type");
1464  }
1465 
1466  // Associate the accumulated result payloads to the op's actual results.
1467  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1468  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1469 
1471 }
1472 
1473 void transform::ForeachOp::getEffects(
1474  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1475  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1476  // arity errors, this method might get called before/in absence of `verify()`.
1477  for (auto &&[target, blockArg] :
1478  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1479  BlockArgument blockArgument = blockArg;
1480  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1481  return isHandleConsumed(blockArgument,
1482  cast<TransformOpInterface>(&op));
1483  })) {
1484  consumesHandle(target, effects);
1485  } else {
1486  onlyReadsHandle(target, effects);
1487  }
1488  }
1489 
1490  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1491  return doesModifyPayload(cast<TransformOpInterface>(&op));
1492  })) {
1493  modifiesPayload(effects);
1494  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1495  return doesReadPayload(cast<TransformOpInterface>(&op));
1496  })) {
1497  onlyReadsPayload(effects);
1498  }
1499 
1500  producesHandle(getOperation()->getOpResults(), effects);
1501 }
1502 
1503 void transform::ForeachOp::getSuccessorRegions(
1504  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1505  Region *bodyRegion = &getBody();
1506  if (point.isParent()) {
1507  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1508  return;
1509  }
1510 
1511  // Branch back to the region or the parent.
1512  assert(point == getBody() && "unexpected region index");
1513  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1514  regions.emplace_back();
1515 }
1516 
1518 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1519  // Each block argument handle is mapped to a subset (one op to be precise)
1520  // of the payload of the corresponding `targets` operand of ForeachOp.
1521  assert(point == getBody() && "unexpected region index");
1522  return getOperation()->getOperands();
1523 }
1524 
1525 transform::YieldOp transform::ForeachOp::getYieldOp() {
1526  return cast<transform::YieldOp>(getBody().front().getTerminator());
1527 }
1528 
1529 LogicalResult transform::ForeachOp::verify() {
1530  for (auto [targetOpt, bodyArgOpt] :
1531  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532  if (!targetOpt || !bodyArgOpt)
1533  return emitOpError() << "expects the same number of targets as the body "
1534  "has block arguments";
1535  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1536  return emitOpError(
1537  "expects co-indexed targets and the body's "
1538  "block arguments to have the same op/value/param type");
1539  }
1540 
1541  for (auto [resultOpt, yieldOperandOpt] :
1542  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543  if (!resultOpt || !yieldOperandOpt)
1544  return emitOpError() << "expects the same number of results as the "
1545  "yield terminator has operands";
1546  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547  return emitOpError("expects co-indexed results and yield "
1548  "operands to have the same op/value/param type");
1549  }
1550 
1551  return success();
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // GetParentOp
1556 //===----------------------------------------------------------------------===//
1557 
1559 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1560  transform::TransformResults &results,
1561  transform::TransformState &state) {
1562  SmallVector<Operation *> parents;
1563  DenseSet<Operation *> resultSet;
1564  for (Operation *target : state.getPayloadOps(getTarget())) {
1565  Operation *parent = target;
1566  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1567  parent = parent->getParentOp();
1568  while (parent) {
1569  bool checkIsolatedFromAbove =
1570  !getIsolatedFromAbove() ||
1572  bool checkOpName = !getOpName().has_value() ||
1573  parent->getName().getStringRef() == *getOpName();
1574  if (checkIsolatedFromAbove && checkOpName)
1575  break;
1576  parent = parent->getParentOp();
1577  }
1578  if (!parent) {
1579  if (getAllowEmptyResults()) {
1580  results.set(llvm::cast<OpResult>(getResult()), parents);
1582  }
1584  emitSilenceableError()
1585  << "could not find a parent op that matches all requirements";
1586  diag.attachNote(target->getLoc()) << "target op";
1587  return diag;
1588  }
1589  }
1590  if (getDeduplicate()) {
1591  if (resultSet.insert(parent).second)
1592  parents.push_back(parent);
1593  } else {
1594  parents.push_back(parent);
1595  }
1596  }
1597  results.set(llvm::cast<OpResult>(getResult()), parents);
1599 }
1600 
1601 //===----------------------------------------------------------------------===//
1602 // GetConsumersOfResult
1603 //===----------------------------------------------------------------------===//
1604 
1606 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1607  transform::TransformResults &results,
1608  transform::TransformState &state) {
1609  int64_t resultNumber = getResultNumber();
1610  auto payloadOps = state.getPayloadOps(getTarget());
1611  if (std::empty(payloadOps)) {
1612  results.set(cast<OpResult>(getResult()), {});
1614  }
1615  if (!llvm::hasSingleElement(payloadOps))
1616  return emitDefiniteFailure()
1617  << "handle must be mapped to exactly one payload op";
1618 
1619  Operation *target = *payloadOps.begin();
1620  if (target->getNumResults() <= resultNumber)
1621  return emitDefiniteFailure() << "result number overflow";
1622  results.set(llvm::cast<OpResult>(getResult()),
1623  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // GetDefiningOp
1629 //===----------------------------------------------------------------------===//
1630 
1632 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1633  transform::TransformResults &results,
1634  transform::TransformState &state) {
1635  SmallVector<Operation *> definingOps;
1636  for (Value v : state.getPayloadValues(getTarget())) {
1637  if (llvm::isa<BlockArgument>(v)) {
1639  emitSilenceableError() << "cannot get defining op of block argument";
1640  diag.attachNote(v.getLoc()) << "target value";
1641  return diag;
1642  }
1643  definingOps.push_back(v.getDefiningOp());
1644  }
1645  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1647 }
1648 
1649 //===----------------------------------------------------------------------===//
1650 // GetProducerOfOperand
1651 //===----------------------------------------------------------------------===//
1652 
1654 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1655  transform::TransformResults &results,
1656  transform::TransformState &state) {
1657  int64_t operandNumber = getOperandNumber();
1658  SmallVector<Operation *> producers;
1659  for (Operation *target : state.getPayloadOps(getTarget())) {
1660  Operation *producer =
1661  target->getNumOperands() <= operandNumber
1662  ? nullptr
1663  : target->getOperand(operandNumber).getDefiningOp();
1664  if (!producer) {
1666  emitSilenceableError()
1667  << "could not find a producer for operand number: " << operandNumber
1668  << " of " << *target;
1669  diag.attachNote(target->getLoc()) << "target op";
1670  return diag;
1671  }
1672  producers.push_back(producer);
1673  }
1674  results.set(llvm::cast<OpResult>(getResult()), producers);
1676 }
1677 
1678 //===----------------------------------------------------------------------===//
1679 // GetOperandOp
1680 //===----------------------------------------------------------------------===//
1681 
1683 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1684  transform::TransformResults &results,
1685  transform::TransformState &state) {
1686  SmallVector<Value> operands;
1687  for (Operation *target : state.getPayloadOps(getTarget())) {
1688  SmallVector<int64_t> operandPositions;
1690  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1691  target->getNumOperands(), operandPositions);
1692  if (diag.isSilenceableFailure()) {
1693  diag.attachNote(target->getLoc())
1694  << "while considering positions of this payload operation";
1695  return diag;
1696  }
1697  llvm::append_range(operands,
1698  llvm::map_range(operandPositions, [&](int64_t pos) {
1699  return target->getOperand(pos);
1700  }));
1701  }
1702  results.setValues(cast<OpResult>(getResult()), operands);
1704 }
1705 
1706 LogicalResult transform::GetOperandOp::verify() {
1707  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1708  getIsInverted(), getIsAll());
1709 }
1710 
1711 //===----------------------------------------------------------------------===//
1712 // GetResultOp
1713 //===----------------------------------------------------------------------===//
1714 
1716 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1717  transform::TransformResults &results,
1718  transform::TransformState &state) {
1719  SmallVector<Value> opResults;
1720  for (Operation *target : state.getPayloadOps(getTarget())) {
1721  SmallVector<int64_t> resultPositions;
1723  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1724  target->getNumResults(), resultPositions);
1725  if (diag.isSilenceableFailure()) {
1726  diag.attachNote(target->getLoc())
1727  << "while considering positions of this payload operation";
1728  return diag;
1729  }
1730  llvm::append_range(opResults,
1731  llvm::map_range(resultPositions, [&](int64_t pos) {
1732  return target->getResult(pos);
1733  }));
1734  }
1735  results.setValues(cast<OpResult>(getResult()), opResults);
1737 }
1738 
1739 LogicalResult transform::GetResultOp::verify() {
1740  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1741  getIsInverted(), getIsAll());
1742 }
1743 
1744 //===----------------------------------------------------------------------===//
1745 // GetTypeOp
1746 //===----------------------------------------------------------------------===//
1747 
1748 void transform::GetTypeOp::getEffects(
1749  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1750  onlyReadsHandle(getValueMutable(), effects);
1751  producesHandle(getOperation()->getOpResults(), effects);
1752  onlyReadsPayload(effects);
1753 }
1754 
1756 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1757  transform::TransformResults &results,
1758  transform::TransformState &state) {
1759  SmallVector<Attribute> params;
1760  for (Value value : state.getPayloadValues(getValue())) {
1761  Type type = value.getType();
1762  if (getElemental()) {
1763  if (auto shaped = dyn_cast<ShapedType>(type)) {
1764  type = shaped.getElementType();
1765  }
1766  }
1767  params.push_back(TypeAttr::get(type));
1768  }
1769  results.setParams(cast<OpResult>(getResult()), params);
1771 }
1772 
1773 //===----------------------------------------------------------------------===//
1774 // IncludeOp
1775 //===----------------------------------------------------------------------===//
1776 
1777 /// Applies the transform ops contained in `block`. Maps `results` to the same
1778 /// values as the operands of the block terminator.
1780 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1782  transform::TransformResults &results) {
1783  // Apply the sequenced ops one by one.
1784  for (Operation &transform : block.without_terminator()) {
1786  state.applyTransform(cast<transform::TransformOpInterface>(transform));
1787  if (result.isDefiniteFailure())
1788  return result;
1789 
1790  if (result.isSilenceableFailure()) {
1791  if (mode == transform::FailurePropagationMode::Propagate) {
1792  // Propagate empty results in case of early exit.
1793  forwardEmptyOperands(&block, state, results);
1794  return result;
1795  }
1796  (void)result.silence();
1797  }
1798  }
1799 
1800  // Forward the operation mapping for values yielded from the sequence to the
1801  // values produced by the sequence op.
1802  transform::detail::forwardTerminatorOperands(&block, state, results);
1804 }
1805 
1807 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1808  transform::TransformResults &results,
1809  transform::TransformState &state) {
1810  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1811  getOperation(), getTarget());
1812  assert(callee && "unverified reference to unknown symbol");
1813 
1814  if (callee.isExternal())
1815  return emitDefiniteFailure() << "unresolved external named sequence";
1816 
1817  // Map operands to block arguments.
1819  detail::prepareValueMappings(mappings, getOperands(), state);
1820  auto scope = state.make_region_scope(callee.getBody());
1821  for (auto &&[arg, map] :
1822  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1823  if (failed(state.mapBlockArgument(arg, map)))
1825  }
1826 
1828  callee.getBody().front(), getFailurePropagationMode(), state, results);
1829  mappings.clear();
1831  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1832  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1833  results.setMappedValues(result, mapping);
1834  return result;
1835 }
1836 
1838 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1839 
1840 void transform::IncludeOp::getEffects(
1841  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1842  // Always mark as modifying the payload.
1843  // TODO: a mechanism to annotate effects on payload. Even when all handles are
1844  // only read, the payload may still be modified, so we currently stay on the
1845  // conservative side and always indicate modification. This may prevent some
1846  // code reordering.
1847  modifiesPayload(effects);
1848 
1849  // Results are always produced.
1850  producesHandle(getOperation()->getOpResults(), effects);
1851 
1852  // Adds default effects to operands and results. This will be added if
1853  // preconditions fail so the trait verifier doesn't complain about missing
1854  // effects and the real precondition failure is reported later on.
1855  auto defaultEffects = [&] {
1856  onlyReadsHandle(getOperation()->getOpOperands(), effects);
1857  };
1858 
1859  // Bail if the callee is unknown. This may run as part of the verification
1860  // process before we verified the validity of the callee or of this op.
1861  auto target =
1862  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1863  if (!target)
1864  return defaultEffects();
1865  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1866  getOperation(), getTarget());
1867  if (!callee)
1868  return defaultEffects();
1869  DiagnosedSilenceableFailure earlyVerifierResult =
1870  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1871  if (!earlyVerifierResult.succeeded()) {
1872  (void)earlyVerifierResult.silence();
1873  return defaultEffects();
1874  }
1875 
1876  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1877  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1878  consumesHandle(getOperation()->getOpOperand(i), effects);
1879  else
1880  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
1881  }
1882 }
1883 
1884 LogicalResult
1885 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1886  // Access through indirection and do additional checking because this may be
1887  // running before the main op verifier.
1888  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1889  if (!targetAttr)
1890  return emitOpError() << "expects a 'target' symbol reference attribute";
1891 
1892  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1893  *this, targetAttr);
1894  if (!target)
1895  return emitOpError() << "does not reference a named transform sequence";
1896 
1897  FunctionType fnType = target.getFunctionType();
1898  if (fnType.getNumInputs() != getNumOperands())
1899  return emitError("incorrect number of operands for callee");
1900 
1901  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1902  if (getOperand(i).getType() != fnType.getInput(i)) {
1903  return emitOpError("operand type mismatch: expected operand type ")
1904  << fnType.getInput(i) << ", but provided "
1905  << getOperand(i).getType() << " for operand number " << i;
1906  }
1907  }
1908 
1909  if (fnType.getNumResults() != getNumResults())
1910  return emitError("incorrect number of results for callee");
1911 
1912  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1913  Type resultType = getResult(i).getType();
1914  Type funcType = fnType.getResult(i);
1915  if (!implementSameTransformInterface(resultType, funcType)) {
1916  return emitOpError() << "type of result #" << i
1917  << " must implement the same transform dialect "
1918  "interface as the corresponding callee result";
1919  }
1920  }
1921 
1923  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1924  /*alsoVerifyInternal=*/true)
1925  .checkAndReport();
1926 }
1927 
1928 //===----------------------------------------------------------------------===//
1929 // MatchOperationEmptyOp
1930 //===----------------------------------------------------------------------===//
1931 
1932 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1933  ::std::optional<::mlir::Operation *> maybeCurrent,
1935  if (!maybeCurrent.has_value()) {
1936  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1938  }
1939  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1940  return emitSilenceableError() << "operation is not empty";
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // MatchOperationNameOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1948  Operation *current, transform::TransformResults &results,
1949  transform::TransformState &state) {
1950  StringRef currentOpName = current->getName().getStringRef();
1951  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1952  if (acceptedAttr.getValue() == currentOpName)
1954  }
1955  return emitSilenceableError() << "wrong operation name";
1956 }
1957 
1958 //===----------------------------------------------------------------------===//
1959 // MatchParamCmpIOp
1960 //===----------------------------------------------------------------------===//
1961 
1963 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1964  transform::TransformResults &results,
1965  transform::TransformState &state) {
1966  auto signedAPIntAsString = [&](const APInt &value) {
1967  std::string str;
1968  llvm::raw_string_ostream os(str);
1969  value.print(os, /*isSigned=*/true);
1970  return str;
1971  };
1972 
1973  ArrayRef<Attribute> params = state.getParams(getParam());
1974  ArrayRef<Attribute> references = state.getParams(getReference());
1975 
1976  if (params.size() != references.size()) {
1977  return emitSilenceableError()
1978  << "parameters have different payload lengths (" << params.size()
1979  << " vs " << references.size() << ")";
1980  }
1981 
1982  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1983  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1984  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1985  if (!intAttr || !refAttr) {
1986  return emitDefiniteFailure()
1987  << "non-integer parameter value not expected";
1988  }
1989  if (intAttr.getType() != refAttr.getType()) {
1990  return emitDefiniteFailure()
1991  << "mismatching integer attribute types in parameter #" << i;
1992  }
1993  APInt value = intAttr.getValue();
1994  APInt refValue = refAttr.getValue();
1995 
1996  // TODO: this copy will not be necessary in C++20.
1997  int64_t position = i;
1998  auto reportError = [&](StringRef direction) {
2000  emitSilenceableError() << "expected parameter to be " << direction
2001  << " " << signedAPIntAsString(refValue)
2002  << ", got " << signedAPIntAsString(value);
2003  diag.attachNote(getParam().getLoc())
2004  << "value # " << position
2005  << " associated with the parameter defined here";
2006  return diag;
2007  };
2008 
2009  switch (getPredicate()) {
2010  case MatchCmpIPredicate::eq:
2011  if (value.eq(refValue))
2012  break;
2013  return reportError("equal to");
2014  case MatchCmpIPredicate::ne:
2015  if (value.ne(refValue))
2016  break;
2017  return reportError("not equal to");
2018  case MatchCmpIPredicate::lt:
2019  if (value.slt(refValue))
2020  break;
2021  return reportError("less than");
2022  case MatchCmpIPredicate::le:
2023  if (value.sle(refValue))
2024  break;
2025  return reportError("less than or equal to");
2026  case MatchCmpIPredicate::gt:
2027  if (value.sgt(refValue))
2028  break;
2029  return reportError("greater than");
2030  case MatchCmpIPredicate::ge:
2031  if (value.sge(refValue))
2032  break;
2033  return reportError("greater than or equal to");
2034  }
2035  }
2037 }
2038 
2039 void transform::MatchParamCmpIOp::getEffects(
2040  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2041  onlyReadsHandle(getParamMutable(), effects);
2042  onlyReadsHandle(getReferenceMutable(), effects);
2043 }
2044 
2045 //===----------------------------------------------------------------------===//
2046 // ParamConstantOp
2047 //===----------------------------------------------------------------------===//
2048 
2050 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2051  transform::TransformResults &results,
2052  transform::TransformState &state) {
2053  results.setParams(cast<OpResult>(getParam()), {getValue()});
2055 }
2056 
2057 //===----------------------------------------------------------------------===//
2058 // MergeHandlesOp
2059 //===----------------------------------------------------------------------===//
2060 
2062 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2063  transform::TransformResults &results,
2064  transform::TransformState &state) {
2065  ValueRange handles = getHandles();
2066  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2067  SmallVector<Operation *> operations;
2068  for (Value operand : handles)
2069  llvm::append_range(operations, state.getPayloadOps(operand));
2070  if (!getDeduplicate()) {
2071  results.set(llvm::cast<OpResult>(getResult()), operations);
2073  }
2074 
2075  SetVector<Operation *> uniqued(operations.begin(), operations.end());
2076  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2078  }
2079 
2080  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2081  SmallVector<Attribute> attrs;
2082  for (Value attribute : handles)
2083  llvm::append_range(attrs, state.getParams(attribute));
2084  if (!getDeduplicate()) {
2085  results.setParams(cast<OpResult>(getResult()), attrs);
2087  }
2088 
2089  SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
2090  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2092  }
2093 
2094  assert(
2095  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2096  "expected value handle type");
2097  SmallVector<Value> payloadValues;
2098  for (Value value : handles)
2099  llvm::append_range(payloadValues, state.getPayloadValues(value));
2100  if (!getDeduplicate()) {
2101  results.setValues(cast<OpResult>(getResult()), payloadValues);
2103  }
2104 
2105  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
2106  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2108 }
2109 
2110 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2111  // Handles may be the same if deduplicating is enabled.
2112  return getDeduplicate();
2113 }
2114 
2115 void transform::MergeHandlesOp::getEffects(
2116  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2117  onlyReadsHandle(getHandlesMutable(), effects);
2118  producesHandle(getOperation()->getOpResults(), effects);
2119 
2120  // There are no effects on the Payload IR as this is only a handle
2121  // manipulation.
2122 }
2123 
2124 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2125  if (getDeduplicate() || getHandles().size() != 1)
2126  return {};
2127 
2128  // If deduplication is not required and there is only one operand, it can be
2129  // used directly instead of merging.
2130  return getHandles().front();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // NamedSequenceOp
2135 //===----------------------------------------------------------------------===//
2136 
2138 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2139  transform::TransformResults &results,
2140  transform::TransformState &state) {
2141  if (isExternal())
2142  return emitDefiniteFailure() << "unresolved external named sequence";
2143 
2144  // Map the entry block argument to the list of operations.
2145  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2146  // without attaching the interface / trait since that is tailored to a
2147  // dangling top-level op that does not get "called".
2148  auto scope = state.make_region_scope(getBody());
2150  state, this->getOperation(), getBody())))
2152 
2153  return applySequenceBlock(getBody().front(),
2154  FailurePropagationMode::Propagate, state, results);
2155 }
2156 
2157 void transform::NamedSequenceOp::getEffects(
2158  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2159 
2161  OperationState &result) {
2163  parser, result, /*allowVariadic=*/false,
2164  getFunctionTypeAttrName(result.name),
2165  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2167  std::string &) { return builder.getFunctionType(inputs, results); },
2168  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2169 }
2170 
2173  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2174  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2175  getResAttrsAttrName());
2176 }
2177 
2178 /// Verifies that a symbol function-like transform dialect operation has the
2179 /// signature and the terminator that have conforming types, i.e., types
2180 /// implementing the same transform dialect type interface. If `allowExternal`
2181 /// is set, allow external symbols (declarations) and don't check the terminator
2182 /// as it may not exist.
2184 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2185  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2188  << "cannot be defined inside another transform op";
2189  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2190  return diag;
2191  }
2192 
2193  if (op.isExternal() || op.getFunctionBody().empty()) {
2194  if (allowExternal)
2196 
2197  return emitSilenceableFailure(op) << "cannot be external";
2198  }
2199 
2200  if (op.getFunctionBody().front().empty())
2201  return emitSilenceableFailure(op) << "expected a non-empty body block";
2202 
2203  Operation *terminator = &op.getFunctionBody().front().back();
2204  if (!isa<transform::YieldOp>(terminator)) {
2206  << "expected '"
2207  << transform::YieldOp::getOperationName()
2208  << "' as terminator";
2209  diag.attachNote(terminator->getLoc()) << "terminator";
2210  return diag;
2211  }
2212 
2213  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2214  return emitSilenceableFailure(terminator)
2215  << "expected terminator to have as many operands as the parent op "
2216  "has results";
2217  }
2218  for (auto [i, operandType, resultType] : llvm::zip_equal(
2219  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2220  terminator->getOperands().getType(), op.getResultTypes())) {
2221  if (operandType == resultType)
2222  continue;
2223  return emitSilenceableFailure(terminator)
2224  << "the type of the terminator operand #" << i
2225  << " must match the type of the corresponding parent op result ("
2226  << operandType << " vs " << resultType << ")";
2227  }
2228 
2230 }
2231 
2232 /// Verification of a NamedSequenceOp. This does not report the error
2233 /// immediately, so it can be used to check for op's well-formedness before the
2234 /// verifier runs, e.g., during trait verification.
2236 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2237  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2238  if (!parent->getAttr(
2239  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2242  << "expects the parent symbol table to have the '"
2243  << transform::TransformDialect::kWithNamedSequenceAttrName
2244  << "' attribute";
2245  diag.attachNote(parent->getLoc()) << "symbol table operation";
2246  return diag;
2247  }
2248  }
2249 
2250  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2253  << "cannot be defined inside another transform op";
2254  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2255  return diag;
2256  }
2257 
2258  if (op.isExternal() || op.getBody().empty())
2259  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2260  emitWarnings);
2261 
2262  if (op.getBody().front().empty())
2263  return emitSilenceableFailure(op) << "expected a non-empty body block";
2264 
2265  Operation *terminator = &op.getBody().front().back();
2266  if (!isa<transform::YieldOp>(terminator)) {
2268  << "expected '"
2269  << transform::YieldOp::getOperationName()
2270  << "' as terminator";
2271  diag.attachNote(terminator->getLoc()) << "terminator";
2272  return diag;
2273  }
2274 
2275  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2276  return emitSilenceableFailure(terminator)
2277  << "expected terminator to have as many operands as the parent op "
2278  "has results";
2279  }
2280  for (auto [i, operandType, resultType] :
2281  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2282  terminator->getOperands().getType(),
2283  op.getFunctionType().getResults())) {
2284  if (operandType == resultType)
2285  continue;
2286  return emitSilenceableFailure(terminator)
2287  << "the type of the terminator operand #" << i
2288  << " must match the type of the corresponding parent op result ("
2289  << operandType << " vs " << resultType << ")";
2290  }
2291 
2292  auto funcOp = cast<FunctionOpInterface>(*op);
2294  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2295  if (!diag.succeeded())
2296  return diag;
2297 
2298  return verifyYieldingSingleBlockOp(funcOp,
2299  /*allowExternal=*/true);
2300 }
2301 
2302 LogicalResult transform::NamedSequenceOp::verify() {
2303  // Actual verification happens in a separate function for reusability.
2304  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2305 }
2306 
2307 template <typename FnTy>
2308 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2309  Type bbArgType, TypeRange extraBindingTypes,
2310  FnTy bodyBuilder) {
2311  SmallVector<Type> types;
2312  types.reserve(1 + extraBindingTypes.size());
2313  types.push_back(bbArgType);
2314  llvm::append_range(types, extraBindingTypes);
2315 
2316  OpBuilder::InsertionGuard guard(builder);
2317  Region *region = state.regions.back().get();
2318  Block *bodyBlock =
2319  builder.createBlock(region, region->begin(), types,
2320  SmallVector<Location>(types.size(), state.location));
2321 
2322  // Populate body.
2323  builder.setInsertionPointToStart(bodyBlock);
2324  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2325  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2326  } else {
2327  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2328  bodyBlock->getArguments().drop_front());
2329  }
2330 }
2331 
2332 void transform::NamedSequenceOp::build(OpBuilder &builder,
2333  OperationState &state, StringRef symName,
2334  Type rootType, TypeRange resultTypes,
2335  SequenceBodyBuilderFn bodyBuilder,
2337  ArrayRef<DictionaryAttr> argAttrs) {
2338  state.addAttribute(SymbolTable::getSymbolAttrName(),
2339  builder.getStringAttr(symName));
2340  state.addAttribute(getFunctionTypeAttrName(state.name),
2342  rootType, resultTypes)));
2343  state.attributes.append(attrs.begin(), attrs.end());
2344  state.addRegion();
2345 
2346  buildSequenceBody(builder, state, rootType,
2347  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2348 }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // NumAssociationsOp
2352 //===----------------------------------------------------------------------===//
2353 
2355 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2356  transform::TransformResults &results,
2357  transform::TransformState &state) {
2358  size_t numAssociations =
2360  .Case([&](TransformHandleTypeInterface opHandle) {
2361  return llvm::range_size(state.getPayloadOps(getHandle()));
2362  })
2363  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2364  return llvm::range_size(state.getPayloadValues(getHandle()));
2365  })
2366  .Case([&](TransformParamTypeInterface param) {
2367  return llvm::range_size(state.getParams(getHandle()));
2368  })
2369  .Default([](Type) {
2370  llvm_unreachable("unknown kind of transform dialect type");
2371  return 0;
2372  });
2373  results.setParams(cast<OpResult>(getNum()),
2374  rewriter.getI64IntegerAttr(numAssociations));
2376 }
2377 
2378 LogicalResult transform::NumAssociationsOp::verify() {
2379  // Verify that the result type accepts an i64 attribute as payload.
2380  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2381  return resultType
2382  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2383  .checkAndReport();
2384 }
2385 
2386 //===----------------------------------------------------------------------===//
2387 // SelectOp
2388 //===----------------------------------------------------------------------===//
2389 
2391 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2392  transform::TransformResults &results,
2393  transform::TransformState &state) {
2394  SmallVector<Operation *> result;
2395  auto payloadOps = state.getPayloadOps(getTarget());
2396  for (Operation *op : payloadOps) {
2397  if (op->getName().getStringRef() == getOpName())
2398  result.push_back(op);
2399  }
2400  results.set(cast<OpResult>(getResult()), result);
2402 }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // SplitHandleOp
2406 //===----------------------------------------------------------------------===//
2407 
2408 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2409  Value target, int64_t numResultHandles) {
2410  result.addOperands(target);
2411  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2412 }
2413 
2415 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2416  transform::TransformResults &results,
2417  transform::TransformState &state) {
2418  int64_t numPayloads =
2420  .Case<TransformHandleTypeInterface>([&](auto x) {
2421  return llvm::range_size(state.getPayloadOps(getHandle()));
2422  })
2423  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2424  return llvm::range_size(state.getPayloadValues(getHandle()));
2425  })
2426  .Case<TransformParamTypeInterface>([&](auto x) {
2427  return llvm::range_size(state.getParams(getHandle()));
2428  })
2429  .Default([](auto x) {
2430  llvm_unreachable("unknown transform dialect type interface");
2431  return -1;
2432  });
2433 
2434  auto produceNumOpsError = [&]() {
2435  return emitSilenceableError()
2436  << getHandle() << " expected to contain " << this->getNumResults()
2437  << " payloads but it contains " << numPayloads << " payloads";
2438  };
2439 
2440  // Fail if there are more payload ops than results and no overflow result was
2441  // specified.
2442  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2443  return produceNumOpsError();
2444 
2445  // Fail if there are more results than payload ops. Unless:
2446  // - "fail_on_payload_too_small" is set to "false", or
2447  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2448  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2449  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2450  return produceNumOpsError();
2451 
2452  // Distribute payloads.
2453  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2454  if (getOverflowResult())
2455  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2456 
2457  auto container = [&]() {
2458  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2459  return llvm::map_to_vector(
2460  state.getPayloadOps(getHandle()),
2461  [](Operation *op) -> MappedValue { return op; });
2462  }
2463  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2464  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2465  [](Value v) -> MappedValue { return v; });
2466  }
2467  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2468  "unsupported kind of transform dialect type");
2469  return llvm::map_to_vector(state.getParams(getHandle()),
2470  [](Attribute a) -> MappedValue { return a; });
2471  }();
2472 
2473  for (auto &&en : llvm::enumerate(container)) {
2474  int64_t resultNum = en.index();
2475  if (resultNum >= getNumResults())
2476  resultNum = *getOverflowResult();
2477  resultHandles[resultNum].push_back(en.value());
2478  }
2479 
2480  // Set transform op results.
2481  for (auto &&it : llvm::enumerate(resultHandles))
2482  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2483  it.value());
2484 
2486 }
2487 
2488 void transform::SplitHandleOp::getEffects(
2489  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2490  onlyReadsHandle(getHandleMutable(), effects);
2491  producesHandle(getOperation()->getOpResults(), effects);
2492  // There are no effects on the Payload IR as this is only a handle
2493  // manipulation.
2494 }
2495 
2496 LogicalResult transform::SplitHandleOp::verify() {
2497  if (getOverflowResult().has_value() &&
2498  !(*getOverflowResult() < getNumResults()))
2499  return emitOpError("overflow_result is not a valid result index");
2500 
2501  for (Type resultType : getResultTypes()) {
2502  if (implementSameTransformInterface(getHandle().getType(), resultType))
2503  continue;
2504 
2505  return emitOpError("expects result types to implement the same transform "
2506  "interface as the operand type");
2507  }
2508 
2509  return success();
2510 }
2511 
2512 //===----------------------------------------------------------------------===//
2513 // ReplicateOp
2514 //===----------------------------------------------------------------------===//
2515 
2517 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2518  transform::TransformResults &results,
2519  transform::TransformState &state) {
2520  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2521  for (const auto &en : llvm::enumerate(getHandles())) {
2522  Value handle = en.value();
2523  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2524  SmallVector<Operation *> current =
2525  llvm::to_vector(state.getPayloadOps(handle));
2526  SmallVector<Operation *> payload;
2527  payload.reserve(numRepetitions * current.size());
2528  for (unsigned i = 0; i < numRepetitions; ++i)
2529  llvm::append_range(payload, current);
2530  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2531  } else {
2532  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2533  "expected param type");
2534  ArrayRef<Attribute> current = state.getParams(handle);
2535  SmallVector<Attribute> params;
2536  params.reserve(numRepetitions * current.size());
2537  for (unsigned i = 0; i < numRepetitions; ++i)
2538  llvm::append_range(params, current);
2539  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2540  params);
2541  }
2542  }
2544 }
2545 
2546 void transform::ReplicateOp::getEffects(
2547  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2548  onlyReadsHandle(getPatternMutable(), effects);
2549  onlyReadsHandle(getHandlesMutable(), effects);
2550  producesHandle(getOperation()->getOpResults(), effects);
2551 }
2552 
2553 //===----------------------------------------------------------------------===//
2554 // SequenceOp
2555 //===----------------------------------------------------------------------===//
2556 
2558 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2559  transform::TransformResults &results,
2560  transform::TransformState &state) {
2561  // Map the entry block argument to the list of operations.
2562  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2563  if (failed(mapBlockArguments(state)))
2565 
2566  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2567  results);
2568 }
2569 
2570 static ParseResult parseSequenceOpOperands(
2571  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2572  Type &rootType,
2573  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2574  SmallVectorImpl<Type> &extraBindingTypes) {
2575  OpAsmParser::UnresolvedOperand rootOperand;
2576  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2577  if (!hasRoot.has_value()) {
2578  root = std::nullopt;
2579  return success();
2580  }
2581  if (failed(hasRoot.value()))
2582  return failure();
2583  root = rootOperand;
2584 
2585  if (succeeded(parser.parseOptionalComma())) {
2586  if (failed(parser.parseOperandList(extraBindings)))
2587  return failure();
2588  }
2589  if (failed(parser.parseColon()))
2590  return failure();
2591 
2592  // The paren is truly optional.
2593  (void)parser.parseOptionalLParen();
2594 
2595  if (failed(parser.parseType(rootType))) {
2596  return failure();
2597  }
2598 
2599  if (!extraBindings.empty()) {
2600  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2601  return failure();
2602  }
2603 
2604  if (extraBindingTypes.size() != extraBindings.size()) {
2605  return parser.emitError(parser.getNameLoc(),
2606  "expected types to be provided for all operands");
2607  }
2608 
2609  // The paren is truly optional.
2610  (void)parser.parseOptionalRParen();
2611  return success();
2612 }
2613 
2615  Value root, Type rootType,
2616  ValueRange extraBindings,
2617  TypeRange extraBindingTypes) {
2618  if (!root)
2619  return;
2620 
2621  printer << root;
2622  bool hasExtras = !extraBindings.empty();
2623  if (hasExtras) {
2624  printer << ", ";
2625  printer.printOperands(extraBindings);
2626  }
2627 
2628  printer << " : ";
2629  if (hasExtras)
2630  printer << "(";
2631 
2632  printer << rootType;
2633  if (hasExtras) {
2634  printer << ", ";
2635  llvm::interleaveComma(extraBindingTypes, printer.getStream());
2636  printer << ")";
2637  }
2638 }
2639 
2640 /// Returns `true` if the given op operand may be consuming the handle value in
2641 /// the Transform IR. That is, if it may have a Free effect on it.
2643  // Conservatively assume the effect being present in absence of the interface.
2644  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2645  if (!iface)
2646  return true;
2647 
2648  return isHandleConsumed(use.get(), iface);
2649 }
2650 
2651 LogicalResult
2653  function_ref<InFlightDiagnostic()> reportError) {
2654  OpOperand *potentialConsumer = nullptr;
2655  for (OpOperand &use : value.getUses()) {
2656  if (!isValueUsePotentialConsumer(use))
2657  continue;
2658 
2659  if (!potentialConsumer) {
2660  potentialConsumer = &use;
2661  continue;
2662  }
2663 
2664  InFlightDiagnostic diag = reportError()
2665  << " has more than one potential consumer";
2666  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2667  << "used here as operand #" << potentialConsumer->getOperandNumber();
2668  diag.attachNote(use.getOwner()->getLoc())
2669  << "used here as operand #" << use.getOperandNumber();
2670  return diag;
2671  }
2672 
2673  return success();
2674 }
2675 
2676 LogicalResult transform::SequenceOp::verify() {
2677  assert(getBodyBlock()->getNumArguments() >= 1 &&
2678  "the number of arguments must have been verified to be more than 1 by "
2679  "PossibleTopLevelTransformOpTrait");
2680 
2681  if (!getRoot() && !getExtraBindings().empty()) {
2682  return emitOpError()
2683  << "does not expect extra operands when used as top-level";
2684  }
2685 
2686  // Check if a block argument has more than one consuming use.
2687  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2688  if (failed(checkDoubleConsume(arg, [this, arg]() {
2689  return (emitOpError() << "block argument #" << arg.getArgNumber());
2690  }))) {
2691  return failure();
2692  }
2693  }
2694 
2695  // Check properties of the nested operations they cannot check themselves.
2696  for (Operation &child : *getBodyBlock()) {
2697  if (!isa<TransformOpInterface>(child) &&
2698  &child != &getBodyBlock()->back()) {
2700  emitOpError()
2701  << "expected children ops to implement TransformOpInterface";
2702  diag.attachNote(child.getLoc()) << "op without interface";
2703  return diag;
2704  }
2705 
2706  for (OpResult result : child.getResults()) {
2707  auto report = [&]() {
2708  return (child.emitError() << "result #" << result.getResultNumber());
2709  };
2710  if (failed(checkDoubleConsume(result, report)))
2711  return failure();
2712  }
2713  }
2714 
2715  if (!getBodyBlock()->mightHaveTerminator())
2716  return emitOpError() << "expects to have a terminator in the body";
2717 
2718  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2719  getOperation()->getResultTypes()) {
2720  InFlightDiagnostic diag = emitOpError()
2721  << "expects the types of the terminator operands "
2722  "to match the types of the result";
2723  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2724  return diag;
2725  }
2726  return success();
2727 }
2728 
2729 void transform::SequenceOp::getEffects(
2730  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2731  getPotentialTopLevelEffects(effects);
2732 }
2733 
2735 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2736  assert(point == getBody() && "unexpected region index");
2737  if (getOperation()->getNumOperands() > 0)
2738  return getOperation()->getOperands();
2739  return OperandRange(getOperation()->operand_end(),
2740  getOperation()->operand_end());
2741 }
2742 
2743 void transform::SequenceOp::getSuccessorRegions(
2744  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2745  if (point.isParent()) {
2746  Region *bodyRegion = &getBody();
2747  regions.emplace_back(bodyRegion, getNumOperands() != 0
2748  ? bodyRegion->getArguments()
2750  return;
2751  }
2752 
2753  assert(point == getBody() && "unexpected region index");
2754  regions.emplace_back(getOperation()->getResults());
2755 }
2756 
2757 void transform::SequenceOp::getRegionInvocationBounds(
2758  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2759  (void)operands;
2760  bounds.emplace_back(1, 1);
2761 }
2762 
2763 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2764  TypeRange resultTypes,
2765  FailurePropagationMode failurePropagationMode,
2766  Value root,
2767  SequenceBodyBuilderFn bodyBuilder) {
2768  build(builder, state, resultTypes, failurePropagationMode, root,
2769  /*extra_bindings=*/ValueRange());
2770  Type bbArgType = root.getType();
2771  buildSequenceBody(builder, state, bbArgType,
2772  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2773 }
2774 
2775 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2776  TypeRange resultTypes,
2777  FailurePropagationMode failurePropagationMode,
2778  Value root, ValueRange extraBindings,
2779  SequenceBodyBuilderArgsFn bodyBuilder) {
2780  build(builder, state, resultTypes, failurePropagationMode, root,
2781  extraBindings);
2782  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2783  bodyBuilder);
2784 }
2785 
2786 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2787  TypeRange resultTypes,
2788  FailurePropagationMode failurePropagationMode,
2789  Type bbArgType,
2790  SequenceBodyBuilderFn bodyBuilder) {
2791  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2792  /*extra_bindings=*/ValueRange());
2793  buildSequenceBody(builder, state, bbArgType,
2794  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2795 }
2796 
2797 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2798  TypeRange resultTypes,
2799  FailurePropagationMode failurePropagationMode,
2800  Type bbArgType, TypeRange extraBindingTypes,
2801  SequenceBodyBuilderArgsFn bodyBuilder) {
2802  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2803  /*extra_bindings=*/ValueRange());
2804  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2805 }
2806 
2807 //===----------------------------------------------------------------------===//
2808 // PrintOp
2809 //===----------------------------------------------------------------------===//
2810 
2811 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2812  StringRef name) {
2813  if (!name.empty())
2814  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2815 }
2816 
2817 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2818  Value target, StringRef name) {
2819  result.addOperands({target});
2820  build(builder, result, name);
2821 }
2822 
2824 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2825  transform::TransformResults &results,
2826  transform::TransformState &state) {
2827  llvm::outs() << "[[[ IR printer: ";
2828  if (getName().has_value())
2829  llvm::outs() << *getName() << " ";
2830 
2831  OpPrintingFlags printFlags;
2832  if (getAssumeVerified().value_or(false))
2833  printFlags.assumeVerified();
2834  if (getUseLocalScope().value_or(false))
2835  printFlags.useLocalScope();
2836  if (getSkipRegions().value_or(false))
2837  printFlags.skipRegions();
2838 
2839  if (!getTarget()) {
2840  llvm::outs() << "top-level ]]]\n";
2841  state.getTopLevel()->print(llvm::outs(), printFlags);
2842  llvm::outs() << "\n";
2844  }
2845 
2846  llvm::outs() << "]]]\n";
2847  for (Operation *target : state.getPayloadOps(getTarget())) {
2848  target->print(llvm::outs(), printFlags);
2849  llvm::outs() << "\n";
2850  }
2851 
2853 }
2854 
2855 void transform::PrintOp::getEffects(
2856  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2857  // We don't really care about mutability here, but `getTarget` now
2858  // unconditionally casts to a specific type before verification could run
2859  // here.
2860  if (!getTargetMutable().empty())
2861  onlyReadsHandle(getTargetMutable()[0], effects);
2862  onlyReadsPayload(effects);
2863 
2864  // There is no resource for stderr file descriptor, so just declare print
2865  // writes into the default resource.
2866  effects.emplace_back(MemoryEffects::Write::get());
2867 }
2868 
2869 //===----------------------------------------------------------------------===//
2870 // VerifyOp
2871 //===----------------------------------------------------------------------===//
2872 
2874 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2875  Operation *target,
2877  transform::TransformState &state) {
2878  if (failed(::mlir::verify(target))) {
2880  << "failed to verify payload op";
2881  diag.attachNote(target->getLoc()) << "payload op";
2882  return diag;
2883  }
2885 }
2886 
2887 void transform::VerifyOp::getEffects(
2888  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2889  transform::onlyReadsHandle(getTargetMutable(), effects);
2890 }
2891 
2892 //===----------------------------------------------------------------------===//
2893 // YieldOp
2894 //===----------------------------------------------------------------------===//
2895 
2896 void transform::YieldOp::getEffects(
2897  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2898  onlyReadsHandle(getOperandsMutable(), effects);
2899 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:287
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:295
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:281
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
type_range getTypes() const
Definition: ValueRange.cpp:26
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:231
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.