MLIR  22.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/Verifier.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Pass/PassRegistry.h"
30 #include "mlir/Transforms/CSE.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include <optional>
44 
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
47 
48 using namespace mlir;
49 
50 static ParseResult parseApplyRegisteredPassOptions(
51  OpAsmParser &parser, DictionaryAttr &options,
52  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
54  Operation *op,
55  DictionaryAttr options,
56  ValueRange dynamicOptions);
57 static ParseResult parseSequenceOpOperands(
58  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59  Type &rootType,
60  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61  SmallVectorImpl<Type> &extraBindingTypes);
62 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63  Value root, Type rootType,
64  ValueRange extraBindings,
65  TypeRange extraBindingTypes);
66 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
67  ArrayAttr matchers, ArrayAttr actions);
68 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
69  ArrayAttr &matchers,
70  ArrayAttr &actions);
71 
72 /// Helper function to check if the given transform op is contained in (or
73 /// equal to) the given payload target op. In that case, an error is returned.
74 /// Transforming transform IR that is currently executing is generally unsafe.
76 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77  Operation *payload) {
78  Operation *transformAncestor = transform.getOperation();
79  while (transformAncestor) {
80  if (transformAncestor == payload) {
82  transform.emitDefiniteFailure()
83  << "cannot apply transform to itself (or one of its ancestors)";
84  diag.attachNote(payload->getLoc()) << "target payload op";
85  return diag;
86  }
87  transformAncestor = transformAncestor->getParentOp();
88  }
90 }
91 
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94 
95 //===----------------------------------------------------------------------===//
96 // AlternativesOp
97 //===----------------------------------------------------------------------===//
98 
99 OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
100  RegionSuccessor successor) {
101  if (!successor.isParent() && getOperation()->getNumOperands() == 1)
102  return getOperation()->getOperands();
103  return OperandRange(getOperation()->operand_end(),
104  getOperation()->operand_end());
105 }
106 
107 void transform::AlternativesOp::getSuccessorRegions(
108  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
109  for (Region &alternative : llvm::drop_begin(
110  getAlternatives(), point.isParent()
111  ? 0
113  ->getParentRegion()
114  ->getRegionNumber() +
115  1)) {
116  regions.emplace_back(&alternative, !getOperands().empty()
117  ? alternative.getArguments()
119  }
120  if (!point.isParent())
121  regions.emplace_back(getOperation(), getOperation()->getResults());
122 }
123 
124 void transform::AlternativesOp::getRegionInvocationBounds(
125  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
126  (void)operands;
127  // The region corresponding to the first alternative is always executed, the
128  // remaining may or may not be executed.
129  bounds.reserve(getNumRegions());
130  bounds.emplace_back(1, 1);
131  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
132 }
133 
135  transform::TransformResults &results) {
136  for (const auto &res : block->getParentOp()->getOpResults())
137  results.set(res, {});
138 }
139 
141 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
143  transform::TransformState &state) {
144  SmallVector<Operation *> originals;
145  if (Value scopeHandle = getScope())
146  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
147  else
148  originals.push_back(state.getTopLevel());
149 
150  for (Operation *original : originals) {
151  if (original->isAncestor(getOperation())) {
152  auto diag = emitDefiniteFailure()
153  << "scope must not contain the transforms being applied";
154  diag.attachNote(original->getLoc()) << "scope";
155  return diag;
156  }
157  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
158  auto diag = emitDefiniteFailure()
159  << "only isolated-from-above ops can be alternative scopes";
160  diag.attachNote(original->getLoc()) << "scope";
161  return diag;
162  }
163  }
164 
165  for (Region &reg : getAlternatives()) {
166  // Clone the scope operations and make the transforms in this alternative
167  // region apply to them by virtue of mapping the block argument (the only
168  // visible handle) to the cloned scope operations. This effectively prevents
169  // the transformation from accessing any IR outside the scope.
170  auto scope = state.make_region_scope(reg);
171  auto clones = llvm::to_vector(
172  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
173  auto deleteClones = llvm::make_scope_exit([&] {
174  for (Operation *clone : clones)
175  clone->erase();
176  });
177  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
179 
180  bool failed = false;
181  for (Operation &transform : reg.front().without_terminator()) {
183  state.applyTransform(cast<TransformOpInterface>(transform));
184  if (result.isSilenceableFailure()) {
185  LDBG() << "alternative failed: " << result.getMessage();
186  failed = true;
187  break;
188  }
189 
190  if (::mlir::failed(result.silence()))
192  }
193 
194  // If all operations in the given alternative succeeded, no need to consider
195  // the rest. Replace the original scoping operation with the clone on which
196  // the transformations were performed.
197  if (!failed) {
198  // We will be using the clones, so cancel their scheduled deletion.
199  deleteClones.release();
200  TrackingListener listener(state, *this);
201  IRRewriter rewriter(getContext(), &listener);
202  for (const auto &kvp : llvm::zip(originals, clones)) {
203  Operation *original = std::get<0>(kvp);
204  Operation *clone = std::get<1>(kvp);
205  original->getBlock()->getOperations().insert(original->getIterator(),
206  clone);
207  rewriter.replaceOp(original, clone->getResults());
208  }
209  detail::forwardTerminatorOperands(&reg.front(), state, results);
211  }
212  }
213  return emitSilenceableError() << "all alternatives failed";
214 }
215 
216 void transform::AlternativesOp::getEffects(
217  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
218  consumesHandle(getOperation()->getOpOperands(), effects);
219  producesHandle(getOperation()->getOpResults(), effects);
220  for (Region *region : getRegions()) {
221  if (!region->empty())
222  producesHandle(region->front().getArguments(), effects);
223  }
224  modifiesPayload(effects);
225 }
226 
227 LogicalResult transform::AlternativesOp::verify() {
228  for (Region &alternative : getAlternatives()) {
229  Block &block = alternative.front();
230  Operation *terminator = block.getTerminator();
231  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
232  InFlightDiagnostic diag = emitOpError()
233  << "expects terminator operands to have the "
234  "same type as results of the operation";
235  diag.attachNote(terminator->getLoc()) << "terminator";
236  return diag;
237  }
238  }
239 
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // AnnotateOp
245 //===----------------------------------------------------------------------===//
246 
248 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
250  transform::TransformState &state) {
251  SmallVector<Operation *> targets =
252  llvm::to_vector(state.getPayloadOps(getTarget()));
253 
255  if (auto paramH = getParam()) {
256  ArrayRef<Attribute> params = state.getParams(paramH);
257  if (params.size() != 1) {
258  if (targets.size() != params.size()) {
259  return emitSilenceableError()
260  << "parameter and target have different payload lengths ("
261  << params.size() << " vs " << targets.size() << ")";
262  }
263  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
264  target->setAttr(getName(), attr);
266  }
267  attr = params[0];
268  }
269  for (auto *target : targets)
270  target->setAttr(getName(), attr);
272 }
273 
274 void transform::AnnotateOp::getEffects(
275  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
276  onlyReadsHandle(getTargetMutable(), effects);
277  onlyReadsHandle(getParamMutable(), effects);
278  modifiesPayload(effects);
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ApplyCommonSubexpressionEliminationOp
283 //===----------------------------------------------------------------------===//
284 
286 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
287  transform::TransformRewriter &rewriter, Operation *target,
288  ApplyToEachResultList &results, transform::TransformState &state) {
289  // Make sure that this transform is not applied to itself. Modifying the
290  // transform IR while it is being interpreted is generally dangerous.
291  DiagnosedSilenceableFailure payloadCheck =
293  if (!payloadCheck.succeeded())
294  return payloadCheck;
295 
296  DominanceInfo domInfo;
297  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
299 }
300 
301 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
302  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
303  transform::onlyReadsHandle(getTargetMutable(), effects);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // ApplyDeadCodeEliminationOp
309 //===----------------------------------------------------------------------===//
310 
311 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
312  transform::TransformRewriter &rewriter, Operation *target,
313  ApplyToEachResultList &results, transform::TransformState &state) {
314  // Make sure that this transform is not applied to itself. Modifying the
315  // transform IR while it is being interpreted is generally dangerous.
316  DiagnosedSilenceableFailure payloadCheck =
318  if (!payloadCheck.succeeded())
319  return payloadCheck;
320 
321  // Maintain a worklist of potentially dead ops.
322  SetVector<Operation *> worklist;
323 
324  // Helper function that adds all defining ops of used values (operands and
325  // operands of nested ops).
326  auto addDefiningOpsToWorklist = [&](Operation *op) {
327  op->walk([&](Operation *op) {
328  for (Value v : op->getOperands())
329  if (Operation *defOp = v.getDefiningOp())
330  if (target->isProperAncestor(defOp))
331  worklist.insert(defOp);
332  });
333  };
334 
335  // Helper function that erases an op.
336  auto eraseOp = [&](Operation *op) {
337  // Remove op and nested ops from the worklist.
338  op->walk([&](Operation *op) {
339  const auto *it = llvm::find(worklist, op);
340  if (it != worklist.end())
341  worklist.erase(it);
342  });
343  rewriter.eraseOp(op);
344  };
345 
346  // Initial walk over the IR.
347  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
348  if (op != target && isOpTriviallyDead(op)) {
349  addDefiningOpsToWorklist(op);
350  eraseOp(op);
351  }
352  });
353 
354  // Erase all ops that have become dead.
355  while (!worklist.empty()) {
356  Operation *op = worklist.pop_back_val();
357  if (!isOpTriviallyDead(op))
358  continue;
359  addDefiningOpsToWorklist(op);
360  eraseOp(op);
361  }
362 
364 }
365 
366 void transform::ApplyDeadCodeEliminationOp::getEffects(
367  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
368  transform::onlyReadsHandle(getTargetMutable(), effects);
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // ApplyPatternsOp
374 //===----------------------------------------------------------------------===//
375 
376 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
377  transform::TransformRewriter &rewriter, Operation *target,
378  ApplyToEachResultList &results, transform::TransformState &state) {
379  // Make sure that this transform is not applied to itself. Modifying the
380  // transform IR while it is being interpreted is generally dangerous. Even
381  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
382  // performs many additional simplifications such as dead code elimination.
383  DiagnosedSilenceableFailure payloadCheck =
385  if (!payloadCheck.succeeded())
386  return payloadCheck;
387 
388  // Gather all specified patterns.
389  MLIRContext *ctx = target->getContext();
391  if (!getRegion().empty()) {
392  for (Operation &op : getRegion().front()) {
393  cast<transform::PatternDescriptorOpInterface>(&op)
394  .populatePatternsWithState(patterns, state);
395  }
396  }
397 
398  // Configure the GreedyPatternRewriteDriver.
400  config.setListener(
401  static_cast<RewriterBase::Listener *>(rewriter.getListener()));
402  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
403 
404  config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
406  : getMaxIterations());
407  config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
409  : getMaxNumRewrites());
410 
411  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
412  // was requested, apply the greedy pattern rewrite only once. (The greedy
413  // pattern rewrite driver already iterates to a fixpoint internally.)
414  bool cseChanged = false;
415  // One or two iterations should be sufficient. Stop iterating after a certain
416  // threshold to make debugging easier.
417  static const int64_t kNumMaxIterations = 50;
418  int64_t iteration = 0;
419  do {
420  LogicalResult result = failure();
421  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
422  // Op is isolated from above. Apply patterns and also perform region
423  // simplification.
424  result = applyPatternsGreedily(target, frozenPatterns, config);
425  } else {
426  // Manually gather list of ops because the other
427  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
428  // from above. This way, patterns can be applied to ops that are not
429  // isolated from above. Regions are not being simplified. Furthermore,
430  // only a single greedy rewrite iteration is performed.
432  target->walk([&](Operation *nestedOp) {
433  if (target != nestedOp)
434  ops.push_back(nestedOp);
435  });
436  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
437  }
438 
439  // A failure typically indicates that the pattern application did not
440  // converge.
441  if (failed(result)) {
442  return emitSilenceableFailure(target)
443  << "greedy pattern application failed";
444  }
445 
446  if (getApplyCse()) {
447  DominanceInfo domInfo;
448  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
449  &cseChanged);
450  }
451  } while (cseChanged && ++iteration < kNumMaxIterations);
452 
453  if (iteration == kNumMaxIterations)
454  return emitDefiniteFailure() << "fixpoint iteration did not converge";
455 
457 }
458 
459 LogicalResult transform::ApplyPatternsOp::verify() {
460  if (!getRegion().empty()) {
461  for (Operation &op : getRegion().front()) {
462  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
463  InFlightDiagnostic diag = emitOpError()
464  << "expected children ops to implement "
465  "PatternDescriptorOpInterface";
466  diag.attachNote(op.getLoc()) << "op without interface";
467  return diag;
468  }
469  }
470  }
471  return success();
472 }
473 
474 void transform::ApplyPatternsOp::getEffects(
475  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
476  transform::onlyReadsHandle(getTargetMutable(), effects);
478 }
479 
480 void transform::ApplyPatternsOp::build(
481  OpBuilder &builder, OperationState &result, Value target,
482  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
483  result.addOperands(target);
484 
485  OpBuilder::InsertionGuard g(builder);
486  Region *region = result.addRegion();
487  builder.createBlock(region);
488  if (bodyBuilder)
489  bodyBuilder(builder, result.location);
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // ApplyCanonicalizationPatternsOp
494 //===----------------------------------------------------------------------===//
495 
496 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
498  MLIRContext *ctx = patterns.getContext();
499  for (Dialect *dialect : ctx->getLoadedDialects())
500  dialect->getCanonicalizationPatterns(patterns);
502  op.getCanonicalizationPatterns(patterns, ctx);
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // ApplyConversionPatternsOp
507 //===----------------------------------------------------------------------===//
508 
509 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
512  MLIRContext *ctx = getContext();
513 
514  // Instantiate the default type converter if a type converter builder is
515  // specified.
516  std::unique_ptr<TypeConverter> defaultTypeConverter;
517  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
518  getDefaultTypeConverter();
519  if (typeConverterBuilder)
520  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
521 
522  // Configure conversion target.
523  ConversionTarget conversionTarget(*getContext());
524  if (getLegalOps())
525  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
526  conversionTarget.addLegalOp(
527  OperationName(cast<StringAttr>(attr).getValue(), ctx));
528  if (getIllegalOps())
529  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
530  conversionTarget.addIllegalOp(
531  OperationName(cast<StringAttr>(attr).getValue(), ctx));
532  if (getLegalDialects())
533  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
534  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
535  if (getIllegalDialects())
536  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
537  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
538 
539  // Gather all specified patterns.
541  // Need to keep the converters alive until after pattern application because
542  // the patterns take a reference to an object that would otherwise get out of
543  // scope.
544  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
545  if (!getPatterns().empty()) {
546  for (Operation &op : getPatterns().front()) {
547  auto descriptor =
548  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
549 
550  // Check if this pattern set specifies a type converter.
551  std::unique_ptr<TypeConverter> typeConverter =
552  descriptor.getTypeConverter();
553  TypeConverter *converter = nullptr;
554  if (typeConverter) {
555  keepAliveConverters.emplace_back(std::move(typeConverter));
556  converter = keepAliveConverters.back().get();
557  } else {
558  // No type converter specified: Use the default type converter.
559  if (!defaultTypeConverter) {
560  auto diag = emitDefiniteFailure()
561  << "pattern descriptor does not specify type "
562  "converter and apply_conversion_patterns op has "
563  "no default type converter";
564  diag.attachNote(op.getLoc()) << "pattern descriptor op";
565  return diag;
566  }
567  converter = defaultTypeConverter.get();
568  }
569 
570  // Add descriptor-specific updates to the conversion target, which may
571  // depend on the final type converter. In structural converters, the
572  // legality of types dictates the dynamic legality of an operation.
573  descriptor.populateConversionTargetRules(*converter, conversionTarget);
574 
575  descriptor.populatePatterns(*converter, patterns);
576  }
577  }
578 
579  // Attach a tracking listener if handles should be preserved. We configure the
580  // listener to allow op replacements with different names, as conversion
581  // patterns typically replace ops with replacement ops that have a different
582  // name.
583  TrackingListenerConfig trackingConfig;
584  trackingConfig.requireMatchingReplacementOpName = false;
585  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
586  ConversionConfig conversionConfig;
587  if (getPreserveHandles())
588  conversionConfig.listener = &trackingListener;
589 
590  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
591  for (Operation *target : state.getPayloadOps(getTarget())) {
592  // Make sure that this transform is not applied to itself. Modifying the
593  // transform IR while it is being interpreted is generally dangerous.
594  DiagnosedSilenceableFailure payloadCheck =
596  if (!payloadCheck.succeeded())
597  return payloadCheck;
598 
599  LogicalResult status = failure();
600  if (getPartialConversion()) {
601  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
602  conversionConfig);
603  } else {
604  status = applyFullConversion(target, conversionTarget, frozenPatterns,
605  conversionConfig);
606  }
607 
608  // Check dialect conversion state.
610  if (failed(status)) {
611  diag = emitSilenceableError() << "dialect conversion failed";
612  diag.attachNote(target->getLoc()) << "target op";
613  }
614 
615  // Check tracking listener error state.
616  DiagnosedSilenceableFailure trackingFailure =
617  trackingListener.checkAndResetError();
618  if (!trackingFailure.succeeded()) {
619  if (diag.succeeded()) {
620  // Tracking failure is the only failure.
621  return trackingFailure;
622  }
623  diag.attachNote() << "tracking listener also failed: "
624  << trackingFailure.getMessage();
625  (void)trackingFailure.silence();
626  }
627 
628  if (!diag.succeeded())
629  return diag;
630  }
631 
633 }
634 
636  if (getNumRegions() != 1 && getNumRegions() != 2)
637  return emitOpError() << "expected 1 or 2 regions";
638  if (!getPatterns().empty()) {
639  for (Operation &op : getPatterns().front()) {
640  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
642  emitOpError() << "expected pattern children ops to implement "
643  "ConversionPatternDescriptorOpInterface";
644  diag.attachNote(op.getLoc()) << "op without interface";
645  return diag;
646  }
647  }
648  }
649  if (getNumRegions() == 2) {
650  Region &typeConverterRegion = getRegion(1);
651  if (!llvm::hasSingleElement(typeConverterRegion.front()))
652  return emitOpError()
653  << "expected exactly one op in default type converter region";
654  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
655  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
656  maybeTypeConverter);
657  if (!typeConverterOp) {
658  InFlightDiagnostic diag = emitOpError()
659  << "expected default converter child op to "
660  "implement TypeConverterBuilderOpInterface";
661  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
662  return diag;
663  }
664  // Check default type converter type.
665  if (!getPatterns().empty()) {
666  for (Operation &op : getPatterns().front()) {
667  auto descriptor =
668  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
669  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
670  return failure();
671  }
672  }
673  }
674  return success();
675 }
676 
677 void transform::ApplyConversionPatternsOp::getEffects(
678  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
679  if (!getPreserveHandles()) {
680  transform::consumesHandle(getTargetMutable(), effects);
681  } else {
682  transform::onlyReadsHandle(getTargetMutable(), effects);
683  }
685 }
686 
687 void transform::ApplyConversionPatternsOp::build(
688  OpBuilder &builder, OperationState &result, Value target,
689  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
690  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
691  result.addOperands(target);
692 
693  {
694  OpBuilder::InsertionGuard g(builder);
695  Region *region1 = result.addRegion();
696  builder.createBlock(region1);
697  if (patternsBodyBuilder)
698  patternsBodyBuilder(builder, result.location);
699  }
700  {
701  OpBuilder::InsertionGuard g(builder);
702  Region *region2 = result.addRegion();
703  builder.createBlock(region2);
704  if (typeConverterBodyBuilder)
705  typeConverterBodyBuilder(builder, result.location);
706  }
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // ApplyToLLVMConversionPatternsOp
711 //===----------------------------------------------------------------------===//
712 
713 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
714  TypeConverter &typeConverter, RewritePatternSet &patterns) {
715  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
716  assert(dialect && "expected that dialect is loaded");
717  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718  // ConversionTarget is currently ignored because the enclosing
719  // apply_conversion_patterns op sets up its own ConversionTarget.
720  ConversionTarget target(*getContext());
721  iface->populateConvertToLLVMConversionPatterns(
722  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
723 }
724 
725 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
726  transform::TypeConverterBuilderOpInterface builder) {
727  if (builder.getTypeConverterType() != "LLVMTypeConverter")
728  return emitOpError("expected LLVMTypeConverter");
729  return success();
730 }
731 
733  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
734  if (!dialect)
735  return emitOpError("unknown dialect or dialect not loaded: ")
736  << getDialectName();
737  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
738  if (!iface)
739  return emitOpError(
740  "dialect does not implement ConvertToLLVMPatternInterface or "
741  "extension was not loaded: ")
742  << getDialectName();
743  return success();
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // ApplyLoopInvariantCodeMotionOp
748 //===----------------------------------------------------------------------===//
749 
751 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
752  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
754  transform::TransformState &state) {
755  // Currently, LICM does not remove operations, so we don't need tracking.
756  // If this ever changes, add a LICM entry point that takes a rewriter.
757  moveLoopInvariantCode(target);
759 }
760 
761 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
762  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
763  transform::onlyReadsHandle(getTargetMutable(), effects);
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // ApplyRegisteredPassOp
769 //===----------------------------------------------------------------------===//
770 
771 void transform::ApplyRegisteredPassOp::getEffects(
772  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
773  consumesHandle(getTargetMutable(), effects);
774  onlyReadsHandle(getDynamicOptionsMutable(), effects);
775  producesHandle(getOperation()->getOpResults(), effects);
776  modifiesPayload(effects);
777 }
778 
780 transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
782  transform::TransformState &state) {
783  // Obtain a single options-string to pass to the pass(-pipeline) from options
784  // passed in as a dictionary of keys mapping to values which are either
785  // attributes or param-operands pointing to attributes.
786  OperandRange dynamicOptions = getDynamicOptions();
787 
788  std::string options;
789  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
790 
791  // A helper to convert an option's attribute value into a corresponding
792  // string representation, with the ability to obtain the attr(s) from a param.
793  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
794  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
795  // The corresponding value attribute(s) is/are passed in via a param.
796  // Obtain the param-operand via its specified index.
797  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
798  assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
799  "the number of ParamOperandAttrs in the options DictionaryAttr"
800  "should be the same as the number of options passed as params");
801  ArrayRef<Attribute> attrsAssociatedToParam =
802  state.getParams(dynamicOptions[dynamicOptionIdx]);
803  // Recursive so as to append all attrs associated to the param.
804  llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
805  ",");
806  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
807  // Recursive so as to append all nested attrs of the array.
808  llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
809  } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
810  // Convert to unquoted string.
811  optionsStream << strAttr.getValue().str();
812  } else {
813  // For all other attributes, ask the attr to print itself (without type).
814  valueAttr.print(optionsStream, /*elideType=*/true);
815  }
816  };
817 
818  // Convert the options DictionaryAttr into a single string.
819  llvm::interleave(
820  getOptions(), optionsStream,
821  [&](auto namedAttribute) {
822  optionsStream << namedAttribute.getName().str(); // Append the key.
823  optionsStream << "="; // And the key-value separator.
824  appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
825  },
826  " ");
827  optionsStream.flush();
828 
829  // Get pass or pass pipeline from registry.
830  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
831  if (!info)
832  info = PassInfo::lookup(getPassName());
833  if (!info)
834  return emitDefiniteFailure()
835  << "unknown pass or pass pipeline: " << getPassName();
836 
837  // Create pass manager and add the pass or pass pipeline.
838  PassManager pm(getContext());
839  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
840  emitError(msg);
841  return failure();
842  }))) {
843  return emitDefiniteFailure()
844  << "failed to add pass or pass pipeline to pipeline: "
845  << getPassName();
846  }
847 
848  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
849  for (Operation *target : targets) {
850  // Make sure that this transform is not applied to itself. Modifying the
851  // transform IR while it is being interpreted is generally dangerous. Even
852  // more so when applying passes because they may perform a wide range of IR
853  // modifications.
854  DiagnosedSilenceableFailure payloadCheck =
856  if (!payloadCheck.succeeded())
857  return payloadCheck;
858 
859  // Run the pass or pass pipeline on the current target operation.
860  if (failed(pm.run(target))) {
861  auto diag = emitSilenceableError() << "pass pipeline failed";
862  diag.attachNote(target->getLoc()) << "target op";
863  return diag;
864  }
865  }
866 
867  // The applied pass will have directly modified the payload IR(s).
868  results.set(llvm::cast<OpResult>(getResult()), targets);
870 }
871 
// Custom parser for the options dictionary of the apply-registered-pass op.
// Builds `options` from a comma-separated list of `key = value` pairs, where a
// value is an attribute, an SSA operand, or a (possibly nested) `[...]` array
// of those. Every operand encountered is appended to `dynamicOptions` and is
// represented in the dictionary by a `ParamOperandAttr` holding its position
// in that list. Fails on malformed pairs, on an explicitly written
// `ParamOperandAttr`, and on duplicate keys.
// NOTE(review): in this listing the first line of the signature, the
// declarations of `attrs` and `operand`, and the leading part of both
// `parseCommaSeparatedList` calls are not visible - confirm against the
// complete file before editing.
873  OpAsmParser &parser, DictionaryAttr &options,
874  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
875  // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
876  SmallVector<NamedAttribute> keyValuePairs;
  // Position the next parsed operand will occupy in `dynamicOptions`.
877  size_t dynamicOptionsIdx = 0;
878 
879  // Helper for allowing parsing of option values which can be of the form:
880  // - a normal attribute
881  // - an operand (which would be converted to an attr referring to the operand)
882  // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
  // Declared as std::function so that the lambda can call itself recursively.
883  std::function<ParseResult(Attribute &)> parseValue =
884  [&](Attribute &valueAttr) -> ParseResult {
885  // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
886  if (succeeded(parser.parseOptionalLSquare())) {
888 
889  // Recursively parse the array's elements, which might be operands.
890  if (parser.parseCommaSeparatedList(
892  [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
893  " in options dictionary") ||
894  parser.parseRSquare())
895  return failure(); // NB: Attempted parse should've output error message.
896 
897  valueAttr = ArrayAttr::get(parser.getContext(), attrs);
898 
899  return success();
900  }
901 
902  // Parse the value, which can be either an attribute or an operand.
903  OptionalParseResult parsedValueAttr =
904  parser.parseOptionalAttribute(valueAttr);
  // No attribute was present at all: the value has to be an operand.
905  if (!parsedValueAttr.has_value()) {
907  ParseResult parsedOperand = parser.parseOperand(operand);
908  if (failed(parsedOperand))
909  return failure(); // NB: Attempted parse should've output error message.
910  // To make use of the operand, we need to store it in the options dict.
911  // As SSA-values cannot occur in attributes, what we do instead is store
912  // an attribute in its place that contains the index of the param-operand,
913  // so that an attr-value associated to the param can be resolved later on.
914  dynamicOptions.push_back(operand);
915  auto wrappedIndex = IntegerAttr::get(
916  IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
917  valueAttr =
918  transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
919  } else if (failed(parsedValueAttr.value())) {
920  return failure(); // NB: Attempted parse should have output error message.
921  } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
  // The marker attribute may only be produced by this parser, never be
  // written by the user in the custom syntax.
922  return parser.emitError(parser.getCurrentLocation())
923  << "the param_operand attribute is a marker reserved for "
924  << "indicating a value will be passed via params and is only used "
925  << "in the generic print format";
926  }
927 
928  return success();
929  };
930 
931  // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
932  // string and `value` looks like either an attribute or an operand-in-an-attr.
933  std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
934  std::string key;
935  Attribute valueAttr;
936 
937  if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
938  return parser.emitError(parser.getCurrentLocation())
939  << "expected key to either be an identifier or a string";
940 
941  if (failed(parser.parseEqual()))
942  return parser.emitError(parser.getCurrentLocation())
943  << "expected '=' after key in key-value pair";
944 
945  if (failed(parseValue(valueAttr)))
946  return parser.emitError(parser.getCurrentLocation())
947  << "expected a valid attribute or operand as value associated "
948  << "to key '" << key << "'";
949 
950  keyValuePairs.push_back(NamedAttribute(key, valueAttr));
951 
952  return success();
953  };
954 
957  " in options dictionary"))
958  return failure(); // NB: Attempted parse should have output error message.
959 
960  if (DictionaryAttr::findDuplicate(
961  keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
962  .has_value())
963  return parser.emitError(parser.getCurrentLocation())
964  << "duplicate keys found in options dictionary";
965 
  // `keyValuePairs` was sorted by the duplicate check above, hence the
  // `getWithSorted` construction.
966  options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
967 
968  return success();
969 }
970 
// Custom printer for the options dictionary: prints nothing when `options` is
// empty, otherwise `{key = value, ...}`. A `ParamOperandAttr` value is printed
// as the SSA operand found at its index in `dynamicOptions`; arrays are
// printed element by element so that operands nested inside them are resolved
// in the same way.
// NOTE(review): the first line of the signature (which declares `printer`) is
// not visible in this listing.
972  Operation *op,
973  DictionaryAttr options,
974  ValueRange dynamicOptions) {
975  if (options.empty())
976  return;
977 
  // Declared as std::function so that the lambda can recurse into arrays.
978  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
979  if (auto paramOperandAttr =
980  dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
981  // Resolve index of param-operand to its actual SSA-value and print that.
982  printer.printOperand(
983  dynamicOptions[paramOperandAttr.getIndex().getInt()]);
984  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
985  // This case is so that ArrayAttr-contained operands are pretty-printed.
986  printer << "[";
987  llvm::interleaveComma(arrayAttr, printer, printOptionValue);
988  printer << "]";
989  } else {
990  printer.printAttribute(valueAttr);
991  }
992  };
993 
994  printer << "{";
995  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
996  printer << namedAttribute.getName();
997  printer << " = ";
998  printOptionValue(namedAttribute.getValue());
999  });
1000  printer << "}";
1001 }
1002 
// Verifier body: every dynamic-option operand must be referenced by exactly
// one `ParamOperandAttr` in the options dictionary (including those nested in
// arrays), and every such attribute must hold an in-bounds index.
// NOTE(review): the line carrying the function signature is not visible in
// this listing.
1004  // Check that there is a one-to-one correspondence between param operands
1005  // and references to dynamic options in the options dictionary.
1006 
  // Local copy of the operands; entries are nulled out as they are seen.
1007  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1008 
1009  // Helper for option values to mark seen operands as having been seen (once).
1010  std::function<LogicalResult(Attribute)> checkOptionValue =
1011  [&](Attribute valueAttr) -> LogicalResult {
1012  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1013  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1014  if (dynamicOptionIdx < 0 ||
1015  dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1016  return emitOpError()
1017  << "dynamic option index " << dynamicOptionIdx
1018  << " is out of bounds for the number of dynamic options: "
1019  << dynamicOptions.size();
  // A null entry means an earlier attribute already referenced this index.
1020  if (dynamicOptions[dynamicOptionIdx] == nullptr)
1021  return emitOpError() << "dynamic option index " << dynamicOptionIdx
1022  << " is already used in options";
1023  dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1024  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1025  // Recurse into ArrayAttrs as they may contain references to operands.
1026  for (auto eltAttr : arrayAttr)
1027  if (failed(checkOptionValue(eltAttr)))
1028  return failure();
1029  }
1030  return success();
1031  };
1032 
1033  for (NamedAttribute namedAttr : getOptions())
1034  if (failed(checkOptionValue(namedAttr.getValue())))
1035  return failure();
1036 
1037  // All dynamicOptions-params seen in the dict will have been set to null.
1038  for (Value dynamicOption : dynamicOptions)
1039  if (dynamicOption)
1040  return emitOpError() << "a param operand does not have a corresponding "
1041  << "param_operand attr in the options dict";
1042 
1043  return success();
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // CastOp
1048 //===----------------------------------------------------------------------===//
1049 
// Per-payload application of the cast: records `target` itself as the result,
// without touching the payload operation.
// NOTE(review): the return-type line and the return statement are not visible
// in this listing.
1051 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1052  Operation *target, ApplyToEachResultList &results,
1053  transform::TransformState &state) {
1054  results.push_back(target);
1056 }
1057 
1058 void transform::CastOp::getEffects(
1059  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1060  onlyReadsPayload(effects);
1061  onlyReadsHandle(getInputMutable(), effects);
1062  producesHandle(getOperation()->getOpResults(), effects);
1063 }
1064 
1065 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1066  assert(inputs.size() == 1 && "expected one input");
1067  assert(outputs.size() == 1 && "expected one output");
1068  return llvm::all_of(
1069  std::initializer_list<Type>{inputs.front(), outputs.front()},
1070  llvm::IsaPred<transform::TransformHandleTypeInterface>);
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // CollectMatchingOp
1075 //===----------------------------------------------------------------------===//
1076 
1077 /// Applies matcher operations from the given `block` using
1078 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
1079 /// accordingly. If any of the matchers fails, be it silenceably or
1080 /// definitely, stops and returns that failure to the caller as is; an
1081 /// operation in the block that does not implement MatchOpInterface results
1082 /// in a definite failure. If all matchers in the block succeed, populates
1083 /// `mappings` with the payload entities associated with the block terminator
1084 /// operands. Note that `mappings` will be cleared before that.
// NOTE(review): the head of the signature (name, `block` and `state`
// parameters), the statement executed when mapping the block arguments fails,
// the declaration of `diag` and the final return are not visible in this
// listing.
1087  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1089  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1090  assert(block.getParent() && "cannot match using a detached block");
  // Region scope for the block's parent region, held until this function
  // returns.
1091  auto matchScope = state.make_region_scope(*block.getParent());
1092  if (failed(
1093  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1095 
1096  for (Operation &match : block.without_terminator()) {
1097  if (!isa<transform::MatchOpInterface>(match)) {
1098  return emitDefiniteFailure(match.getLoc())
1099  << "expected operations in the match part to "
1100  "implement MatchOpInterface";
1101  }
1103  state.applyTransform(cast<transform::TransformOpInterface>(match));
1104  if (diag.succeeded())
1105  continue;
1106 
  // First failing matcher: hand the failure to the caller unchanged.
1107  return diag;
1108  }
1109 
1110  // Remember the values mapped to the terminator operands so we can
1111  // forward them to the action.
1112  ValueRange yieldedValues = block.getTerminator()->getOperands();
1113  // Our contract with the caller is that the mappings will contain only the
1114  // newly mapped values, clear the rest.
1115  mappings.clear();
1116  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1118 }
1119 
1120 /// Returns `true` if both types implement one of the interfaces provided as
1121 /// template parameters.
1122 template <typename... Tys>
1123 static bool implementSameInterface(Type t1, Type t2) {
1124  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1125 }
1126 
1127 /// Returns `true` if both types implement one of the transform dialect
1128 /// interfaces.
// The interfaces considered are the operation-handle, parameter and
// value-handle type interfaces; both types must implement the same one.
// NOTE(review): the signature line (taking `t1` and `t2`) is not visible in
// this listing.
1130  return implementSameInterface<transform::TransformHandleTypeInterface,
1131  transform::TransformParamTypeInterface,
1132  transform::TransformValueHandleTypeInterface>(
1133  t1, t2);
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // CollectMatchingOp
1138 //===----------------------------------------------------------------------===//
1139 
1141 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1142  transform::TransformResults &results,
1143  transform::TransformState &state) {
1144  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1145  getOperation(), getMatcher());
1146  if (matcher.isExternal()) {
1147  return emitDefiniteFailure()
1148  << "unresolved external symbol " << getMatcher();
1149  }
1150 
1151  SmallVector<SmallVector<MappedValue>, 2> rawResults;
1152  rawResults.resize(getOperation()->getNumResults());
1153  std::optional<DiagnosedSilenceableFailure> maybeFailure;
1154  for (Operation *root : state.getPayloadOps(getRoot())) {
1155  WalkResult walkResult = root->walk([&](Operation *op) {
1156  LDBG(DEBUG_TYPE_MATCHER, 1)
1157  << "matching "
1158  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1159  << " @" << op;
1160 
1161  // Try matching.
1163  SmallVector<transform::MappedValue> inputMapping({op});
1165  matcher.getFunctionBody().front(),
1166  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1167  mappings);
1168  if (diag.isDefiniteFailure())
1169  return WalkResult::interrupt();
1170  if (diag.isSilenceableFailure()) {
1171  LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1172  << " failed: " << diag.getMessage();
1173  return WalkResult::advance();
1174  }
1175 
1176  // If succeeded, collect results.
1177  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1178  if (mapping.size() != 1) {
1179  maybeFailure.emplace(emitSilenceableError()
1180  << "result #" << i << ", associated with "
1181  << mapping.size()
1182  << " payload objects, expected 1");
1183  return WalkResult::interrupt();
1184  }
1185  rawResults[i].push_back(mapping[0]);
1186  }
1187  return WalkResult::advance();
1188  });
1189  if (walkResult.wasInterrupted())
1190  return std::move(*maybeFailure);
1191  assert(!maybeFailure && "failure set but the walk was not interrupted");
1192 
1193  for (auto &&[opResult, rawResult] :
1194  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1195  results.setMappedValues(opResult, rawResult);
1196  }
1197  }
1199 }
1200 
1201 void transform::CollectMatchingOp::getEffects(
1202  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1203  onlyReadsHandle(getRootMutable(), effects);
1204  producesHandle(getOperation()->getOpResults(), effects);
1205  onlyReadsPayload(effects);
1206 }
1207 
1208 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1209  SymbolTableCollection &symbolTable) {
1210  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1211  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1212  if (!matcherSymbol ||
1213  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1214  return emitError() << "unresolved matcher symbol " << getMatcher();
1215 
1216  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1217  if (argumentTypes.size() != 1 ||
1218  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1219  return emitError()
1220  << "expected the matcher to take one operation handle argument";
1221  }
1222  if (!matcherSymbol.getArgAttr(
1223  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1224  return emitError() << "expected the matcher argument to be marked readonly";
1225  }
1226 
1227  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1228  if (resultTypes.size() != getOperation()->getNumResults()) {
1229  return emitError()
1230  << "expected the matcher to yield as many values as op has results ("
1231  << getOperation()->getNumResults() << "), got "
1232  << resultTypes.size();
1233  }
1234 
1235  for (auto &&[i, matcherType, resultType] :
1236  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1237  if (implementSameTransformInterface(matcherType, resultType))
1238  continue;
1239 
1240  return emitError()
1241  << "mismatching type interfaces for matcher result and op result #"
1242  << i;
1243  }
1244 
1245  return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // ForeachMatchOp
1250 //===----------------------------------------------------------------------===//
1251 
1252 // This is fine because nothing is actually consumed by this op.
1253 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1254 
// Walks the payload under each root and, for every visited operation, tries
// the matcher/action pairs in order: the action of the first matcher that
// succeeds is run on what that matcher yielded, and the remaining pairs are
// skipped for that operation. Silenceable failures of matchers mean "no
// match" and are only logged; silenceable failures of actions are accumulated
// into `overallDiag`, which is returned at the end. Definite failures
// interrupt the walk.
// NOTE(review): several lines are not visible in this listing - the return
// type, the declaration of `matchActionPairs`, the initializer of
// `overallDiag`, the declarations of `diag` and `result`, the head of the
// call that collects the action results together with the start of the
// diagnostic that follows it, and the statement executed when the walk was
// interrupted. Confirm against the complete file before editing.
1256 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1257  transform::TransformResults &results,
1258  transform::TransformState &state) {
1260  matchActionPairs;
1261  matchActionPairs.reserve(getMatchers().size());
1262  SymbolTableCollection symbolTable;
  // Resolve all matcher/action symbols up front.
1263  for (auto &&[matcher, action] :
1264  llvm::zip_equal(getMatchers(), getActions())) {
1265  auto matcherSymbol =
1266  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1267  getOperation(), cast<SymbolRefAttr>(matcher));
1268  auto actionSymbol =
1269  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1270  getOperation(), cast<SymbolRefAttr>(action));
1271  assert(matcherSymbol && actionSymbol &&
1272  "unresolved symbols not caught by the verifier");
1273 
1274  if (matcherSymbol.isExternal())
1275  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1276  if (actionSymbol.isExternal())
1277  return emitDefiniteFailure() << "unresolved external symbol " << action;
1278 
1279  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1280  }
1281 
1282  DiagnosedSilenceableFailure overallDiag =
1284 
1285  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1286  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1287  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1288  // Explicitly add the mapping for the first block argument (the op being
1289  // matched).
1290  matchInputMapping.emplace_back();
1291  transform::detail::prepareValueMappings(matchInputMapping,
1292  getForwardedInputs(), state);
  // Slot rewritten for every visited operation; the forwarded inputs that
  // follow it stay the same across the whole walk.
1293  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1294  actionResultMapping.resize(getForwardedOutputs().size());
1295 
1296  for (Operation *root : state.getPayloadOps(getRoot())) {
1297  WalkResult walkResult = root->walk([&](Operation *op) {
1298  // If getRestrictRoot is not present, skip over the root op itself so we
1299  // don't invalidate it.
1300  if (!getRestrictRoot() && op == root)
1301  return WalkResult::advance();
1302 
1303  LDBG(DEBUG_TYPE_MATCHER, 1)
1304  << "matching "
1305  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1306  << " @" << op;
1307 
1308  firstMatchArgument.clear();
1309  firstMatchArgument.push_back(op);
1310 
1311  // Try all the match/action pairs until the first successful match.
1312  for (auto [matcher, action] : matchActionPairs) {
1314  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1315  state, matchOutputMapping);
1316  if (diag.isDefiniteFailure())
1317  return WalkResult::interrupt();
1318  if (diag.isSilenceableFailure()) {
1319  LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1320  << " failed: " << diag.getMessage();
1321  continue;
1322  }
1323 
  // The matcher succeeded: bind what it yielded to the action's arguments.
1324  auto scope = state.make_region_scope(action.getFunctionBody());
1325  if (failed(state.mapBlockArguments(
1326  action.getFunctionBody().front().getArguments(),
1327  matchOutputMapping))) {
1328  return WalkResult::interrupt();
1329  }
1330 
1331  for (Operation &transform :
1332  action.getFunctionBody().front().without_terminator()) {
1334  state.applyTransform(cast<TransformOpInterface>(transform));
1335  if (result.isDefiniteFailure())
1336  return WalkResult::interrupt();
1337  if (result.isSilenceableFailure()) {
  // Record the failure and keep going with the remaining transforms of
  // the action.
1338  if (overallDiag.succeeded()) {
1339  overallDiag = emitSilenceableError() << "actions failed";
1340  }
1341  overallDiag.attachNote(action->getLoc())
1342  << "failed action: " << result.getMessage();
1343  overallDiag.attachNote(op->getLoc())
1344  << "when applied to this matching payload";
1345  (void)result.silence();
1346  continue;
1347  }
1348  }
1350  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1351  action.getFunctionBody().front().getTerminator()->getOperands(),
1352  state, getFlattenResults()))) {
1354  << "action @" << action.getName()
1355  << " has results associated with multiple payload entities, "
1356  "but flattening was not requested";
1357  return WalkResult::interrupt();
1358  }
  // Only the first matching pair is applied to a given operation.
1359  break;
1360  }
1361  return WalkResult::advance();
1362  });
1363  if (walkResult.wasInterrupted())
1365  }
1366 
1367  // The root operation should not have been affected, so we can just reassign
1368  // the payload to the result. Note that we need to consume the root handle to
1369  // make sure any handles to operations inside, that could have been affected
1370  // by actions, are invalidated.
1371  results.set(llvm::cast<OpResult>(getUpdated()),
1372  state.getPayloadOps(getRoot()));
1373  for (auto &&[result, mapping] :
1374  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1375  results.setMappedValues(result, mapping);
1376  }
1377  return overallDiag;
1378 }
1379 
1380 void transform::ForeachMatchOp::getAsmResultNames(
1381  OpAsmSetValueNameFn setNameFn) {
1382  setNameFn(getUpdated(), "updated_root");
1383  for (Value v : getForwardedOutputs()) {
1384  setNameFn(v, "yielded");
1385  }
1386 }
1387 
1388 void transform::ForeachMatchOp::getEffects(
1389  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1390  // Bail if invalid.
1391  if (getOperation()->getNumOperands() < 1 ||
1392  getOperation()->getNumResults() < 1) {
1393  return modifiesPayload(effects);
1394  }
1395 
1396  consumesHandle(getRootMutable(), effects);
1397  onlyReadsHandle(getForwardedInputsMutable(), effects);
1398  producesHandle(getOperation()->getOpResults(), effects);
1399  modifiesPayload(effects);
1400 }
1401 
1402 /// Parses the comma-separated list of symbol reference pairs of the format
1403 /// `@matcher -> @action`.
1404 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1405  ArrayAttr &matchers,
1406  ArrayAttr &actions) {
1407  StringAttr matcher;
1408  StringAttr action;
1409  SmallVector<Attribute> matcherList;
1410  SmallVector<Attribute> actionList;
1411  do {
1412  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1413  parser.parseSymbolName(action)) {
1414  return failure();
1415  }
1416  matcherList.push_back(SymbolRefAttr::get(matcher));
1417  actionList.push_back(SymbolRefAttr::get(action));
1418  } while (parser.parseOptionalComma().succeeded());
1419 
1420  matchers = parser.getBuilder().getArrayAttr(matcherList);
1421  actions = parser.getBuilder().getArrayAttr(actionList);
1422  return success();
1423 }
1424 
1425 /// Prints the comma-separated list of symbol reference pairs of the format
1426 /// `@matcher -> @action`.
// Each pair is printed on a line of its own, with the printer's indentation
// increased twice for the duration of the list.
// NOTE(review): the first line of the signature (which declares `printer`) is
// not visible in this listing.
1428  ArrayAttr matchers, ArrayAttr actions) {
1429  printer.increaseIndent();
1430  printer.increaseIndent();
  // The running index is zipped in only to detect the last pair, which must
  // not be followed by a separator.
1431  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1432  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1433  printer.printNewline();
1434  printer << cast<SymbolRefAttr>(matcher) << " -> "
1435  << cast<SymbolRefAttr>(action);
1436  if (idx != matchers.size() - 1)
1437  printer << ", ";
1438  }
1439  printer.decreaseIndent();
1440  printer.decreaseIndent();
1441 }
1442 
1443 LogicalResult transform::ForeachMatchOp::verify() {
1444  if (getMatchers().size() != getActions().size())
1445  return emitOpError() << "expected the same number of matchers and actions";
1446  if (getMatchers().empty())
1447  return emitOpError() << "expected at least one match/action pair";
1448 
1449  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1450  for (Attribute name : getMatchers()) {
1451  if (matcherNames.insert(name).second)
1452  continue;
1453  emitWarning() << "matcher " << name
1454  << " is used more than once, only the first match will apply";
1455  }
1456 
1457  return success();
1458 }
1459 
1460 /// Checks that the attributes of the function-like operation have correct
1461 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1462 /// annotations being present even if they can be inferred from the body.
/// If `emitWarnings`, additionally warns about arguments that are marked as
/// consumed although the body does not consume them. Violations are reported
/// as silenceable errors attached to `op`.
// NOTE(review): the return-type line and the final return statement are not
// visible in this listing.
1464 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1465  bool alsoVerifyInternal = false) {
1466  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  // Arguments consumed by the body; stays empty for external functions,
  // which have no body to inspect.
1467  llvm::SmallDenseSet<unsigned> consumedArguments;
1468  if (!op.isExternal()) {
1469  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1470  consumedArguments);
1471  }
1472  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1473  bool isConsumed =
1474  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1475  nullptr;
1476  bool isReadOnly =
1477  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1478  nullptr;
1479  if (isConsumed && isReadOnly) {
1480  return transformOp.emitSilenceableError()
1481  << "argument #" << i << " cannot be both readonly and consumed";
1482  }
1483  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1484  return transformOp.emitSilenceableError()
1485  << "must provide consumed/readonly status for arguments of "
1486  "external or called ops";
1487  }
  // The remaining checks compare the annotations with the body.
1488  if (op.isExternal())
1489  continue;
1490 
1491  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1492  return transformOp.emitSilenceableError()
1493  << "argument #" << i
1494  << " is consumed in the body but is not marked as such";
1495  }
1496  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1497  // Cannot use op.emitWarning() here as it would attempt to verify the op
1498  // before printing, resulting in infinite recursion.
1499  emitWarning(op->getLoc())
1500  << "op argument #" << i
1501  << " is not consumed in the body but is marked as consumed";
1502  }
1503  }
1505 }
1506 
// Verifies every matcher/action pair against this op: both symbols must
// resolve to function-like transform ops with complete consumed/readonly
// annotations; the op operands must agree with the matcher arguments (which
// must not be consumed), the matcher results with the action arguments, and
// the action results with the op results that follow the first one. Agreement
// means equal count and, position by position, the same kind of transform
// type interface.
// NOTE(review): the head of the second annotation check and the declarations
// of the `diag` variables are not visible in this listing.
1507 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1508  SymbolTableCollection &symbolTable) {
1509  assert(getMatchers().size() == getActions().size());
1510  auto consumedAttr =
1511  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1512  for (auto &&[matcher, action] :
1513  llvm::zip_equal(getMatchers(), getActions())) {
1514  // Presence and typing.
1515  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1516  symbolTable.lookupNearestSymbolFrom(getOperation(),
1517  cast<SymbolRefAttr>(matcher)));
1518  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1519  symbolTable.lookupNearestSymbolFrom(getOperation(),
1520  cast<SymbolRefAttr>(action)));
1521  if (!matcherSymbol ||
1522  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1523  return emitError() << "unresolved matcher symbol " << matcher;
1524  if (!actionSymbol ||
1525  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1526  return emitError() << "unresolved action symbol " << action;
1527 
1528  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1529  /*emitWarnings=*/false,
1530  /*alsoVerifyInternal=*/true)
1531  .checkAndReport())) {
1532  return failure();
1533  }
1535  /*emitWarnings=*/false,
1536  /*alsoVerifyInternal=*/true)
1537  .checkAndReport())) {
1538  return failure();
1539  }
1540 
1541  // Input -> matcher forwarding.
1542  TypeRange operandTypes = getOperandTypes();
1543  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1544  if (operandTypes.size() != matcherArguments.size()) {
1546  emitError() << "the number of operands (" << operandTypes.size()
1547  << ") doesn't match the number of matcher arguments ("
1548  << matcherArguments.size() << ") for " << matcher;
1549  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1550  return diag;
1551  }
1552  for (auto &&[i, operand, argument] :
1553  llvm::enumerate(operandTypes, matcherArguments)) {
1554  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1556  emitOpError()
1557  << "does not expect matcher symbol to consume its operand #" << i;
1558  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1559  return diag;
1560  }
1561 
1562  if (implementSameTransformInterface(operand, argument))
1563  continue;
1564 
1566  emitError()
1567  << "mismatching type interfaces for operand and matcher argument #"
1568  << i << " of matcher " << matcher;
1569  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1570  return diag;
1571  }
1572 
1573  // Matcher -> action forwarding.
1574  TypeRange matcherResults = matcherSymbol.getResultTypes();
1575  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1576  if (matcherResults.size() != actionArguments.size()) {
1577  return emitError() << "mismatching number of matcher results and "
1578  "action arguments between "
1579  << matcher << " (" << matcherResults.size() << ") and "
1580  << action << " (" << actionArguments.size() << ")";
1581  }
1582  for (auto &&[i, matcherType, actionType] :
1583  llvm::enumerate(matcherResults, actionArguments)) {
1584  if (implementSameTransformInterface(matcherType, actionType))
1585  continue;
1586 
  // NOTE(review): the message below has no space between the index and
  // "of matcher", unlike the operand/matcher message above - fix the string
  // in a code change.
1587  return emitError() << "mismatching type interfaces for matcher result "
1588  "and action argument #"
1589  << i << "of matcher " << matcher << " and action "
1590  << action;
1591  }
1592 
1593  // Action -> result forwarding.
1594  TypeRange actionResults = actionSymbol.getResultTypes();
  // The first op result is left out of the comparison with the action
  // results.
1595  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1596  if (actionResults.size() != resultTypes.size()) {
1598  emitError() << "the number of action results ("
1599  << actionResults.size() << ") for " << action
1600  << " doesn't match the number of extra op results ("
1601  << resultTypes.size() << ")";
1602  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1603  return diag;
1604  }
1605  for (auto &&[i, resultType, actionType] :
1606  llvm::enumerate(resultTypes, actionResults)) {
1607  if (implementSameTransformInterface(resultType, actionType))
1608  continue;
1609 
1611  emitError() << "mismatching type interfaces for action result #" << i
1612  << " of action " << action << " and op result";
1613  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1614  return diag;
1615  }
1616  }
1617  return success();
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // ForeachOp
1622 //===----------------------------------------------------------------------===//
1623 
// Runs the body once per "row" of the targets' payloads: iteration `k` binds
// each block argument to the single `k`-th payload element of the
// corresponding target, executes the body, and appends whatever the body
// yields to the matching op result. Unless `zip_shortest` is set, all targets
// must have payloads of the same size; with it, all payloads are truncated to
// the smallest one. The first non-successful transform in the body aborts the
// loop and its failure is returned.
// NOTE(review): the return-type line, the declaration of `payloads`, the
// statement executed when mapping a block argument fails and the final return
// are not visible in this listing.
1625 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1626  transform::TransformResults &results,
1627  transform::TransformState &state) {
1628  // We store the payloads before executing the body as ops may be removed from
1629  // the mapping by the TrackingRewriter while iteration is in progress.
1631  detail::prepareValueMappings(payloads, getTargets(), state);
1632  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1633  bool withZipShortest = getWithZipShortest();
1634 
1635  // In case of `zip_shortest`, set the number of iterations to the
1636  // smallest payload in the targets.
  // NOTE(review): the iterator returned by `min_element` is dereferenced
  // unconditionally; presumably the op always has at least one target so
  // that `payloads` is never empty here - confirm against the op definition.
1637  if (withZipShortest) {
1638  numIterations =
1639  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1640  const SmallVector<MappedValue> &b) {
1641  return a.size() < b.size();
1642  })->size();
1643 
1644  for (auto &payload : payloads)
1645  payload.resize(numIterations);
1646  }
1647 
1648  // As we will be "zipping" over them, check all payloads have the same size.
1649  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1650  // when true.
1651  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1652  argIdx++) {
1653  if (payloads[argIdx].size() != numIterations) {
1654  return emitSilenceableError()
1655  << "prior targets' payload size (" << numIterations
1656  << ") differs from payload size (" << payloads[argIdx].size()
1657  << ") of target " << getTargets()[argIdx];
1658  }
1659  }
1660 
1661  // Start iterating, indexing into payloads to obtain the right arguments to
1662  // call the body with - each slice of payloads at the same argument index
1663  // corresponding to a tuple to use as the body's block arguments.
1664  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
  // One accumulator per op result, filled across all iterations.
1665  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1666  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
  // A fresh region scope is created for every iteration.
1667  auto scope = state.make_region_scope(getBody());
1668  // Set up arguments to the region's block.
1669  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1670  MappedValue argument = payloads[argIdx][iterIdx];
1671  // Note that each blockArg's handle gets associated with just a single
1672  // element from the corresponding target's payload.
1673  if (failed(state.mapBlockArgument(blockArg, {argument})))
1675  }
1676 
1677  // Execute loop body.
1678  for (Operation &transform : getBody().front().without_terminator()) {
1679  DiagnosedSilenceableFailure result = state.applyTransform(
1680  llvm::cast<transform::TransformOpInterface>(transform));
1681  if (!result.succeeded())
1682  return result;
1683  }
1684 
1685  // Append yielded payloads to corresponding results from prior iterations.
1686  OperandRange yieldOperands = getYieldOp().getOperands();
1687  for (auto &&[result, yieldOperand, resTuple] :
1688  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1689  // NB: each iteration we add any number of ops/vals/params to a result.
1690  if (isa<TransformHandleTypeInterface>(result.getType()))
1691  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1692  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1693  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1694  else if (isa<TransformParamTypeInterface>(result.getType()))
1695  llvm::append_range(resTuple, state.getParams(yieldOperand));
1696  else
1697  assert(false && "unhandled handle type");
1698  }
1699 
1700  // Associate the accumulated result payloads to the op's actual results.
1701  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1702  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1703 
1705 }
1706 
1707 void transform::ForeachOp::getEffects(
1708  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1709  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1710  // arity errors, this method might get called before/in absence of `verify()`.
1711  for (auto &&[target, blockArg] :
1712  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1713  BlockArgument blockArgument = blockArg;
1714  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1715  return isHandleConsumed(blockArgument,
1716  cast<TransformOpInterface>(&op));
1717  })) {
1718  consumesHandle(target, effects);
1719  } else {
1720  onlyReadsHandle(target, effects);
1721  }
1722  }
1723 
1724  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1725  return doesModifyPayload(cast<TransformOpInterface>(&op));
1726  })) {
1727  modifiesPayload(effects);
1728  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1729  return doesReadPayload(cast<TransformOpInterface>(&op));
1730  })) {
1731  onlyReadsPayload(effects);
1732  }
1733 
1734  producesHandle(getOperation()->getOpResults(), effects);
1735 }
1736 
1737 void transform::ForeachOp::getSuccessorRegions(
1738  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1739  Region *bodyRegion = &getBody();
1740  if (point.isParent()) {
1741  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1742  return;
1743  }
1744 
1745  // Branch back to the region or the parent.
1746  assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
1747  &getBody() &&
1748  "unexpected region index");
1749  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1750  regions.emplace_back(getOperation(), getOperation()->getResults());
1751 }
1752 
1754 transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1755  // Each block argument handle is mapped to a subset (one op to be precise)
1756  // of the payload of the corresponding `targets` operand of ForeachOp.
1757  assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1758  return getOperation()->getOperands();
1759 }
1760 
1761 transform::YieldOp transform::ForeachOp::getYieldOp() {
1762  return cast<transform::YieldOp>(getBody().front().getTerminator());
1763 }
1764 
1765 LogicalResult transform::ForeachOp::verify() {
1766  for (auto [targetOpt, bodyArgOpt] :
1767  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1768  if (!targetOpt || !bodyArgOpt)
1769  return emitOpError() << "expects the same number of targets as the body "
1770  "has block arguments";
1771  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1772  return emitOpError(
1773  "expects co-indexed targets and the body's "
1774  "block arguments to have the same op/value/param type");
1775  }
1776 
1777  for (auto [resultOpt, yieldOperandOpt] :
1778  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1779  if (!resultOpt || !yieldOperandOpt)
1780  return emitOpError() << "expects the same number of results as the "
1781  "yield terminator has operands";
1782  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1783  return emitOpError("expects co-indexed results and yield "
1784  "operands to have the same op/value/param type");
1785  }
1786 
1787  return success();
1788 }
1789 
1790 //===----------------------------------------------------------------------===//
1791 // GetParentOp
1792 //===----------------------------------------------------------------------===//
1793 
1795 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1796  transform::TransformResults &results,
1797  transform::TransformState &state) {
1798  SmallVector<Operation *> parents;
1799  DenseSet<Operation *> resultSet;
1800  for (Operation *target : state.getPayloadOps(getTarget())) {
1801  Operation *parent = target;
1802  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1803  parent = parent->getParentOp();
1804  while (parent) {
1805  bool checkIsolatedFromAbove =
1806  !getIsolatedFromAbove() ||
1808  bool checkOpName = !getOpName().has_value() ||
1809  parent->getName().getStringRef() == *getOpName();
1810  if (checkIsolatedFromAbove && checkOpName)
1811  break;
1812  parent = parent->getParentOp();
1813  }
1814  if (!parent) {
1815  if (getAllowEmptyResults()) {
1816  results.set(llvm::cast<OpResult>(getResult()), parents);
1818  }
1820  emitSilenceableError()
1821  << "could not find a parent op that matches all requirements";
1822  diag.attachNote(target->getLoc()) << "target op";
1823  return diag;
1824  }
1825  }
1826  if (getDeduplicate()) {
1827  if (resultSet.insert(parent).second)
1828  parents.push_back(parent);
1829  } else {
1830  parents.push_back(parent);
1831  }
1832  }
1833  results.set(llvm::cast<OpResult>(getResult()), parents);
1835 }
1836 
1837 //===----------------------------------------------------------------------===//
1838 // GetConsumersOfResult
1839 //===----------------------------------------------------------------------===//
1840 
1842 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1843  transform::TransformResults &results,
1844  transform::TransformState &state) {
1845  int64_t resultNumber = getResultNumber();
1846  auto payloadOps = state.getPayloadOps(getTarget());
1847  if (std::empty(payloadOps)) {
1848  results.set(cast<OpResult>(getResult()), {});
1850  }
1851  if (!llvm::hasSingleElement(payloadOps))
1852  return emitDefiniteFailure()
1853  << "handle must be mapped to exactly one payload op";
1854 
1855  Operation *target = *payloadOps.begin();
1856  if (target->getNumResults() <= resultNumber)
1857  return emitDefiniteFailure() << "result number overflow";
1858  results.set(llvm::cast<OpResult>(getResult()),
1859  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1861 }
1862 
1863 //===----------------------------------------------------------------------===//
1864 // GetDefiningOp
1865 //===----------------------------------------------------------------------===//
1866 
1868 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1869  transform::TransformResults &results,
1870  transform::TransformState &state) {
1871  SmallVector<Operation *> definingOps;
1872  for (Value v : state.getPayloadValues(getTarget())) {
1873  if (llvm::isa<BlockArgument>(v)) {
1875  emitSilenceableError() << "cannot get defining op of block argument";
1876  diag.attachNote(v.getLoc()) << "target value";
1877  return diag;
1878  }
1879  definingOps.push_back(v.getDefiningOp());
1880  }
1881  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // GetProducerOfOperand
1887 //===----------------------------------------------------------------------===//
1888 
1890 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1891  transform::TransformResults &results,
1892  transform::TransformState &state) {
1893  int64_t operandNumber = getOperandNumber();
1894  SmallVector<Operation *> producers;
1895  for (Operation *target : state.getPayloadOps(getTarget())) {
1896  Operation *producer =
1897  target->getNumOperands() <= operandNumber
1898  ? nullptr
1899  : target->getOperand(operandNumber).getDefiningOp();
1900  if (!producer) {
1902  emitSilenceableError()
1903  << "could not find a producer for operand number: " << operandNumber
1904  << " of " << *target;
1905  diag.attachNote(target->getLoc()) << "target op";
1906  return diag;
1907  }
1908  producers.push_back(producer);
1909  }
1910  results.set(llvm::cast<OpResult>(getResult()), producers);
1912 }
1913 
1914 //===----------------------------------------------------------------------===//
1915 // GetOperandOp
1916 //===----------------------------------------------------------------------===//
1917 
1919 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1920  transform::TransformResults &results,
1921  transform::TransformState &state) {
1922  SmallVector<Value> operands;
1923  for (Operation *target : state.getPayloadOps(getTarget())) {
1924  SmallVector<int64_t> operandPositions;
1926  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1927  target->getNumOperands(), operandPositions);
1928  if (diag.isSilenceableFailure()) {
1929  diag.attachNote(target->getLoc())
1930  << "while considering positions of this payload operation";
1931  return diag;
1932  }
1933  llvm::append_range(operands,
1934  llvm::map_range(operandPositions, [&](int64_t pos) {
1935  return target->getOperand(pos);
1936  }));
1937  }
1938  results.setValues(cast<OpResult>(getResult()), operands);
1940 }
1941 
1942 LogicalResult transform::GetOperandOp::verify() {
1943  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1944  getIsInverted(), getIsAll());
1945 }
1946 
1947 //===----------------------------------------------------------------------===//
1948 // GetResultOp
1949 //===----------------------------------------------------------------------===//
1950 
1952 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1953  transform::TransformResults &results,
1954  transform::TransformState &state) {
1955  SmallVector<Value> opResults;
1956  for (Operation *target : state.getPayloadOps(getTarget())) {
1957  SmallVector<int64_t> resultPositions;
1959  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1960  target->getNumResults(), resultPositions);
1961  if (diag.isSilenceableFailure()) {
1962  diag.attachNote(target->getLoc())
1963  << "while considering positions of this payload operation";
1964  return diag;
1965  }
1966  llvm::append_range(opResults,
1967  llvm::map_range(resultPositions, [&](int64_t pos) {
1968  return target->getResult(pos);
1969  }));
1970  }
1971  results.setValues(cast<OpResult>(getResult()), opResults);
1973 }
1974 
1975 LogicalResult transform::GetResultOp::verify() {
1976  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1977  getIsInverted(), getIsAll());
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // GetTypeOp
1982 //===----------------------------------------------------------------------===//
1983 
1984 void transform::GetTypeOp::getEffects(
1985  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1986  onlyReadsHandle(getValueMutable(), effects);
1987  producesHandle(getOperation()->getOpResults(), effects);
1988  onlyReadsPayload(effects);
1989 }
1990 
1992 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1993  transform::TransformResults &results,
1994  transform::TransformState &state) {
1995  SmallVector<Attribute> params;
1996  for (Value value : state.getPayloadValues(getValue())) {
1997  Type type = value.getType();
1998  if (getElemental()) {
1999  if (auto shaped = dyn_cast<ShapedType>(type)) {
2000  type = shaped.getElementType();
2001  }
2002  }
2003  params.push_back(TypeAttr::get(type));
2004  }
2005  results.setParams(cast<OpResult>(getResult()), params);
2007 }
2008 
2009 //===----------------------------------------------------------------------===//
2010 // IncludeOp
2011 //===----------------------------------------------------------------------===//
2012 
2013 /// Applies the transform ops contained in `block`. Maps `results` to the same
2014 /// values as the operands of the block terminator.
2016 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2018  transform::TransformResults &results) {
2019  // Apply the sequenced ops one by one.
2020  for (Operation &transform : block.without_terminator()) {
2022  state.applyTransform(cast<transform::TransformOpInterface>(transform));
2023  if (result.isDefiniteFailure())
2024  return result;
2025 
2026  if (result.isSilenceableFailure()) {
2027  if (mode == transform::FailurePropagationMode::Propagate) {
2028  // Propagate empty results in case of early exit.
2029  forwardEmptyOperands(&block, state, results);
2030  return result;
2031  }
2032  (void)result.silence();
2033  }
2034  }
2035 
2036  // Forward the operation mapping for values yielded from the sequence to the
2037  // values produced by the sequence op.
2038  transform::detail::forwardTerminatorOperands(&block, state, results);
2040 }
2041 
2043 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2044  transform::TransformResults &results,
2045  transform::TransformState &state) {
2046  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2047  getOperation(), getTarget());
2048  assert(callee && "unverified reference to unknown symbol");
2049 
2050  if (callee.isExternal())
2051  return emitDefiniteFailure() << "unresolved external named sequence";
2052 
2053  // Map operands to block arguments.
2055  detail::prepareValueMappings(mappings, getOperands(), state);
2056  auto scope = state.make_region_scope(callee.getBody());
2057  for (auto &&[arg, map] :
2058  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2059  if (failed(state.mapBlockArgument(arg, map)))
2061  }
2062 
2064  callee.getBody().front(), getFailurePropagationMode(), state, results);
2065  mappings.clear();
2067  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2068  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2069  results.setMappedValues(result, mapping);
2070  return result;
2071 }
2072 
2074 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2075 
2076 void transform::IncludeOp::getEffects(
2077  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2078  // Always mark as modifying the payload.
2079  // TODO: a mechanism to annotate effects on payload. Even when all handles are
2080  // only read, the payload may still be modified, so we currently stay on the
2081  // conservative side and always indicate modification. This may prevent some
2082  // code reordering.
2083  modifiesPayload(effects);
2084 
2085  // Results are always produced.
2086  producesHandle(getOperation()->getOpResults(), effects);
2087 
2088  // Adds default effects to operands and results. This will be added if
2089  // preconditions fail so the trait verifier doesn't complain about missing
2090  // effects and the real precondition failure is reported later on.
2091  auto defaultEffects = [&] {
2092  onlyReadsHandle(getOperation()->getOpOperands(), effects);
2093  };
2094 
2095  // Bail if the callee is unknown. This may run as part of the verification
2096  // process before we verified the validity of the callee or of this op.
2097  auto target =
2098  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2099  if (!target)
2100  return defaultEffects();
2101  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2102  getOperation(), getTarget());
2103  if (!callee)
2104  return defaultEffects();
2105 
2106  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2107  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2108  consumesHandle(getOperation()->getOpOperand(i), effects);
2109  else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2110  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2111  }
2112 }
2113 
2114 LogicalResult
2115 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2116  // Access through indirection and do additional checking because this may be
2117  // running before the main op verifier.
2118  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2119  if (!targetAttr)
2120  return emitOpError() << "expects a 'target' symbol reference attribute";
2121 
2122  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2123  *this, targetAttr);
2124  if (!target)
2125  return emitOpError() << "does not reference a named transform sequence";
2126 
2127  FunctionType fnType = target.getFunctionType();
2128  if (fnType.getNumInputs() != getNumOperands())
2129  return emitError("incorrect number of operands for callee");
2130 
2131  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2132  if (getOperand(i).getType() != fnType.getInput(i)) {
2133  return emitOpError("operand type mismatch: expected operand type ")
2134  << fnType.getInput(i) << ", but provided "
2135  << getOperand(i).getType() << " for operand number " << i;
2136  }
2137  }
2138 
2139  if (fnType.getNumResults() != getNumResults())
2140  return emitError("incorrect number of results for callee");
2141 
2142  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2143  Type resultType = getResult(i).getType();
2144  Type funcType = fnType.getResult(i);
2145  if (!implementSameTransformInterface(resultType, funcType)) {
2146  return emitOpError() << "type of result #" << i
2147  << " must implement the same transform dialect "
2148  "interface as the corresponding callee result";
2149  }
2150  }
2151 
2153  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2154  /*alsoVerifyInternal=*/true)
2155  .checkAndReport();
2156 }
2157 
2158 //===----------------------------------------------------------------------===//
2159 // MatchOperationEmptyOp
2160 //===----------------------------------------------------------------------===//
2161 
2162 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2163  ::std::optional<::mlir::Operation *> maybeCurrent,
2165  if (!maybeCurrent.has_value()) {
2166  LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2168  }
2169  LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2170  return emitSilenceableError() << "operation is not empty";
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // MatchOperationNameOp
2175 //===----------------------------------------------------------------------===//
2176 
2177 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2178  Operation *current, transform::TransformResults &results,
2179  transform::TransformState &state) {
2180  StringRef currentOpName = current->getName().getStringRef();
2181  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2182  if (acceptedAttr.getValue() == currentOpName)
2184  }
2185  return emitSilenceableError() << "wrong operation name";
2186 }
2187 
2188 //===----------------------------------------------------------------------===//
2189 // MatchParamCmpIOp
2190 //===----------------------------------------------------------------------===//
2191 
2193 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2194  transform::TransformResults &results,
2195  transform::TransformState &state) {
2196  auto signedAPIntAsString = [&](const APInt &value) {
2197  std::string str;
2198  llvm::raw_string_ostream os(str);
2199  value.print(os, /*isSigned=*/true);
2200  return str;
2201  };
2202 
2203  ArrayRef<Attribute> params = state.getParams(getParam());
2204  ArrayRef<Attribute> references = state.getParams(getReference());
2205 
2206  if (params.size() != references.size()) {
2207  return emitSilenceableError()
2208  << "parameters have different payload lengths (" << params.size()
2209  << " vs " << references.size() << ")";
2210  }
2211 
2212  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2213  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2214  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2215  if (!intAttr || !refAttr) {
2216  return emitDefiniteFailure()
2217  << "non-integer parameter value not expected";
2218  }
2219  if (intAttr.getType() != refAttr.getType()) {
2220  return emitDefiniteFailure()
2221  << "mismatching integer attribute types in parameter #" << i;
2222  }
2223  APInt value = intAttr.getValue();
2224  APInt refValue = refAttr.getValue();
2225 
2226  // TODO: this copy will not be necessary in C++20.
2227  int64_t position = i;
2228  auto reportError = [&](StringRef direction) {
2230  emitSilenceableError() << "expected parameter to be " << direction
2231  << " " << signedAPIntAsString(refValue)
2232  << ", got " << signedAPIntAsString(value);
2233  diag.attachNote(getParam().getLoc())
2234  << "value # " << position
2235  << " associated with the parameter defined here";
2236  return diag;
2237  };
2238 
2239  switch (getPredicate()) {
2240  case MatchCmpIPredicate::eq:
2241  if (value.eq(refValue))
2242  break;
2243  return reportError("equal to");
2244  case MatchCmpIPredicate::ne:
2245  if (value.ne(refValue))
2246  break;
2247  return reportError("not equal to");
2248  case MatchCmpIPredicate::lt:
2249  if (value.slt(refValue))
2250  break;
2251  return reportError("less than");
2252  case MatchCmpIPredicate::le:
2253  if (value.sle(refValue))
2254  break;
2255  return reportError("less than or equal to");
2256  case MatchCmpIPredicate::gt:
2257  if (value.sgt(refValue))
2258  break;
2259  return reportError("greater than");
2260  case MatchCmpIPredicate::ge:
2261  if (value.sge(refValue))
2262  break;
2263  return reportError("greater than or equal to");
2264  }
2265  }
2267 }
2268 
2269 void transform::MatchParamCmpIOp::getEffects(
2270  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2271  onlyReadsHandle(getParamMutable(), effects);
2272  onlyReadsHandle(getReferenceMutable(), effects);
2273 }
2274 
2275 //===----------------------------------------------------------------------===//
2276 // ParamConstantOp
2277 //===----------------------------------------------------------------------===//
2278 
2280 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2281  transform::TransformResults &results,
2282  transform::TransformState &state) {
2283  results.setParams(cast<OpResult>(getParam()), {getValue()});
2285 }
2286 
2287 //===----------------------------------------------------------------------===//
2288 // MergeHandlesOp
2289 //===----------------------------------------------------------------------===//
2290 
2292 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2293  transform::TransformResults &results,
2294  transform::TransformState &state) {
2295  ValueRange handles = getHandles();
2296  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2297  SmallVector<Operation *> operations;
2298  for (Value operand : handles)
2299  llvm::append_range(operations, state.getPayloadOps(operand));
2300  if (!getDeduplicate()) {
2301  results.set(llvm::cast<OpResult>(getResult()), operations);
2303  }
2304 
2305  SetVector<Operation *> uniqued(llvm::from_range, operations);
2306  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2308  }
2309 
2310  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2311  SmallVector<Attribute> attrs;
2312  for (Value attribute : handles)
2313  llvm::append_range(attrs, state.getParams(attribute));
2314  if (!getDeduplicate()) {
2315  results.setParams(cast<OpResult>(getResult()), attrs);
2317  }
2318 
2319  SetVector<Attribute> uniqued(llvm::from_range, attrs);
2320  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2322  }
2323 
2324  assert(
2325  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2326  "expected value handle type");
2327  SmallVector<Value> payloadValues;
2328  for (Value value : handles)
2329  llvm::append_range(payloadValues, state.getPayloadValues(value));
2330  if (!getDeduplicate()) {
2331  results.setValues(cast<OpResult>(getResult()), payloadValues);
2333  }
2334 
2335  SetVector<Value> uniqued(llvm::from_range, payloadValues);
2336  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2338 }
2339 
2340 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2341  // Handles may be the same if deduplicating is enabled.
2342  return getDeduplicate();
2343 }
2344 
2345 void transform::MergeHandlesOp::getEffects(
2346  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2347  onlyReadsHandle(getHandlesMutable(), effects);
2348  producesHandle(getOperation()->getOpResults(), effects);
2349 
2350  // There are no effects on the Payload IR as this is only a handle
2351  // manipulation.
2352 }
2353 
2354 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2355  if (getDeduplicate() || getHandles().size() != 1)
2356  return {};
2357 
2358  // If deduplication is not required and there is only one operand, it can be
2359  // used directly instead of merging.
2360  return getHandles().front();
2361 }
2362 
2363 //===----------------------------------------------------------------------===//
2364 // NamedSequenceOp
2365 //===----------------------------------------------------------------------===//
2366 
2368 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2369  transform::TransformResults &results,
2370  transform::TransformState &state) {
2371  if (isExternal())
2372  return emitDefiniteFailure() << "unresolved external named sequence";
2373 
2374  // Map the entry block argument to the list of operations.
2375  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2376  // without attaching the interface / trait since that is tailored to a
2377  // dangling top-level op that does not get "called".
2378  auto scope = state.make_region_scope(getBody());
2380  state, this->getOperation(), getBody())))
2382 
2383  return applySequenceBlock(getBody().front(),
2384  FailurePropagationMode::Propagate, state, results);
2385 }
2386 
2387 void transform::NamedSequenceOp::getEffects(
2388  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2389 
2391  OperationState &result) {
2393  parser, result, /*allowVariadic=*/false,
2394  getFunctionTypeAttrName(result.name),
2395  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2397  std::string &) { return builder.getFunctionType(inputs, results); },
2398  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2399 }
2400 
2403  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2404  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2405  getResAttrsAttrName());
2406 }
2407 
2408 /// Verifies that a symbol function-like transform dialect operation has the
2409 /// signature and the terminator that have conforming types, i.e., types
2410 /// implementing the same transform dialect type interface. If `allowExternal`
2411 /// is set, allow external symbols (declarations) and don't check the terminator
2412 /// as it may not exist.
2414 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2415  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2418  << "cannot be defined inside another transform op";
2419  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2420  return diag;
2421  }
2422 
2423  if (op.isExternal() || op.getFunctionBody().empty()) {
2424  if (allowExternal)
2426 
2427  return emitSilenceableFailure(op) << "cannot be external";
2428  }
2429 
2430  if (op.getFunctionBody().front().empty())
2431  return emitSilenceableFailure(op) << "expected a non-empty body block";
2432 
2433  Operation *terminator = &op.getFunctionBody().front().back();
2434  if (!isa<transform::YieldOp>(terminator)) {
2436  << "expected '"
2437  << transform::YieldOp::getOperationName()
2438  << "' as terminator";
2439  diag.attachNote(terminator->getLoc()) << "terminator";
2440  return diag;
2441  }
2442 
2443  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2444  return emitSilenceableFailure(terminator)
2445  << "expected terminator to have as many operands as the parent op "
2446  "has results";
2447  }
2448  for (auto [i, operandType, resultType] : llvm::zip_equal(
2449  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2450  terminator->getOperands().getType(), op.getResultTypes())) {
2451  if (operandType == resultType)
2452  continue;
2453  return emitSilenceableFailure(terminator)
2454  << "the type of the terminator operand #" << i
2455  << " must match the type of the corresponding parent op result ("
2456  << operandType << " vs " << resultType << ")";
2457  }
2458 
2460 }
2461 
2462 /// Verification of a NamedSequenceOp. This does not report the error
2463 /// immediately, so it can be used to check for op's well-formedness before the
2464 /// verifier runs, e.g., during trait verification.
// `emitWarnings` is only forwarded to verifyFunctionLikeConsumeAnnotations.
// NOTE(review): this listing skips numbered lines 2465, 2470-2471, 2481-2482,
// 2497 and 2523 (the return type and the declarations of `diag`); confirm
// them against the original file before editing this function.
2466 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
// If the op is nested in a symbol table op, that op must carry the
// kWithNamedSequenceAttrName attribute.
2467  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2468  if (!parent->getAttr(
2469  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2472  << "expects the parent symbol table to have the '"
2473  << transform::TransformDialect::kWithNamedSequenceAttrName
2474  << "' attribute";
2475  diag.attachNote(parent->getLoc()) << "symbol table operation";
2476  return diag;
2477  }
2478  }
2479 
// A named sequence must not have an ancestor implementing
// TransformOpInterface.
2480  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2483  << "cannot be defined inside another transform op";
2484  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2485  return diag;
2486  }
2487 
// Declarations without a body only get their consume annotations checked.
2488  if (op.isExternal() || op.getBody().empty())
2489  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2490  emitWarnings);
2491 
2492  if (op.getBody().front().empty())
2493  return emitSilenceableFailure(op) << "expected a non-empty body block";
2494 
// The last op of the first block must be a transform::YieldOp.
2495  Operation *terminator = &op.getBody().front().back();
2496  if (!isa<transform::YieldOp>(terminator)) {
2498  << "expected '"
2499  << transform::YieldOp::getOperationName()
2500  << "' as terminator";
2501  diag.attachNote(terminator->getLoc()) << "terminator";
2502  return diag;
2503  }
2504 
// The terminator operands must match the function results in number and,
// position by position, in type.
2505  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2506  return emitSilenceableFailure(terminator)
2507  << "expected terminator to have as many operands as the parent op "
2508  "has results";
2509  }
2510  for (auto [i, operandType, resultType] :
2511  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2512  terminator->getOperands().getType(),
2513  op.getFunctionType().getResults())) {
2514  if (operandType == resultType)
2515  continue;
2516  return emitSilenceableFailure(terminator)
2517  << "the type of the terminator operand #" << i
2518  << " must match the type of the corresponding parent op result ("
2519  << operandType << " vs " << resultType << ")";
2520  }
2521 
// Finally check the consume annotations, then the generic yielding
// single-block checks.
2522  auto funcOp = cast<FunctionOpInterface>(*op);
2524  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2525  if (!diag.succeeded())
2526  return diag;
2527 
2528  return verifyYieldingSingleBlockOp(funcOp,
2529  /*allowExternal=*/true);
2530 }
2531 
2532 LogicalResult transform::NamedSequenceOp::verify() {
2533  // Actual verification happens in a separate function for reusability.
2534  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2535 }
2536 
2537 template <typename FnTy>
2538 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2539  Type bbArgType, TypeRange extraBindingTypes,
2540  FnTy bodyBuilder) {
2541  SmallVector<Type> types;
2542  types.reserve(1 + extraBindingTypes.size());
2543  types.push_back(bbArgType);
2544  llvm::append_range(types, extraBindingTypes);
2545 
2546  OpBuilder::InsertionGuard guard(builder);
2547  Region *region = state.regions.back().get();
2548  Block *bodyBlock =
2549  builder.createBlock(region, region->begin(), types,
2550  SmallVector<Location>(types.size(), state.location));
2551 
2552  // Populate body.
2553  builder.setInsertionPointToStart(bodyBlock);
2554  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2555  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2556  } else {
2557  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2558  bodyBlock->getArguments().drop_front());
2559  }
2560 }
2561 
// Builds a named sequence: records the symbol name and function type
// attributes, appends `attrs`, adds one region and fills it through
// buildSequenceBody with a single block argument of type `rootType`.
// NOTE(review): this listing skips numbered lines 2566 (presumably the
// declaration of the `attrs` parameter) and 2571 (the start of the function
// type attribute expression); confirm against the original file.
// NOTE(review): `argAttrs` is not referenced in the visible lines.
2562 void transform::NamedSequenceOp::build(OpBuilder &builder,
2563  OperationState &state, StringRef symName,
2564  Type rootType, TypeRange resultTypes,
2565  SequenceBodyBuilderFn bodyBuilder,
2567  ArrayRef<DictionaryAttr> argAttrs) {
2568  state.addAttribute(SymbolTable::getSymbolAttrName(),
2569  builder.getStringAttr(symName));
2570  state.addAttribute(getFunctionTypeAttrName(state.name),
2572  rootType, resultTypes)));
2573  state.attributes.append(attrs.begin(), attrs.end());
2574  state.addRegion();
2575 
// The body block gets no extra bindings, only the `rootType` argument.
2576  buildSequenceBody(builder, state, rootType,
2577  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2578 }
2579 
2580 //===----------------------------------------------------------------------===//
2581 // NumAssociationsOp
2582 //===----------------------------------------------------------------------===//
2583 
// Counts the payload entries associated with the handle operand and stores
// the count as an i64 integer attribute parameter on the `num` result.
// NOTE(review): this listing skips numbered lines 2584 (return type), 2589
// (the head of the dispatch expression the `.Case` calls chain onto) and 2602
// (presumably the final return statement); confirm against the original file.
2585 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2586  transform::TransformResults &results,
2587  transform::TransformState &state) {
2588  size_t numAssociations =
// One case per kind of transform type: payload ops, payload values, params.
2590  .Case([&](TransformHandleTypeInterface opHandle) {
2591  return llvm::range_size(state.getPayloadOps(getHandle()));
2592  })
2593  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2594  return llvm::range_size(state.getPayloadValues(getHandle()));
2595  })
2596  .Case([&](TransformParamTypeInterface param) {
2597  return llvm::range_size(state.getParams(getHandle()));
2598  })
2599  .DefaultUnreachable("unknown kind of transform dialect type");
2600  results.setParams(cast<OpResult>(getNum()),
2601  rewriter.getI64IntegerAttr(numAssociations));
2603 }
2604 
2605 LogicalResult transform::NumAssociationsOp::verify() {
2606  // Verify that the result type accepts an i64 attribute as payload.
2607  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2608  return resultType
2609  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2610  .checkAndReport();
2611 }
2612 
2613 //===----------------------------------------------------------------------===//
2614 // SelectOp
2615 //===----------------------------------------------------------------------===//
2616 
// Keeps, in their original order, the payload ops of the target handle whose
// operation name equals `getOpName()`, and associates them with the result.
// NOTE(review): this listing skips numbered lines 2617 (return type) and 2628
// (presumably the final return statement); confirm against the original file.
2618 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2619  transform::TransformResults &results,
2620  transform::TransformState &state) {
2621  SmallVector<Operation *> result;
2622  auto payloadOps = state.getPayloadOps(getTarget());
2623  for (Operation *op : payloadOps) {
2624  if (op->getName().getStringRef() == getOpName())
2625  result.push_back(op);
2626  }
2627  results.set(cast<OpResult>(getResult()), result);
2629 }
2630 
2631 //===----------------------------------------------------------------------===//
2632 // SplitHandleOp
2633 //===----------------------------------------------------------------------===//
2634 
2635 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2636  Value target, int64_t numResultHandles) {
2637  result.addOperands(target);
2638  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2639 }
2640 
// Distributes the payloads of the operand handle over the op results: the
// i-th payload goes to the i-th result, and payloads past the last result go
// to the overflow result when one is specified.
// NOTE(review): this listing skips numbered lines 2641 (return type), 2646
// (the head of the dispatch expression the `.Case` calls chain onto) and 2709
// (presumably the final return statement); confirm against the original file.
2642 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2643  transform::TransformResults &results,
2644  transform::TransformState &state) {
// Count the payloads according to the kind of the handle type.
2645  int64_t numPayloads =
2647  .Case<TransformHandleTypeInterface>([&](auto x) {
2648  return llvm::range_size(state.getPayloadOps(getHandle()));
2649  })
2650  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2651  return llvm::range_size(state.getPayloadValues(getHandle()));
2652  })
2653  .Case<TransformParamTypeInterface>([&](auto x) {
2654  return llvm::range_size(state.getParams(getHandle()));
2655  })
2656  .DefaultUnreachable("unknown transform dialect type interface");
2657 
// Shared silenceable error for both count mismatches below.
2658  auto produceNumOpsError = [&]() {
2659  return emitSilenceableError()
2660  << getHandle() << " expected to contain " << this->getNumResults()
2661  << " payloads but it contains " << numPayloads << " payloads";
2662  };
2663 
2664  // Fail if there are more payload ops than results and no overflow result was
2665  // specified.
2666  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2667  return produceNumOpsError();
2668 
2669  // Fail if there are more results than payload ops. Unless:
2670  // - "fail_on_payload_too_small" is set to "false", or
2671  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2672  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2673  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2674  return produceNumOpsError();
2675 
2676  // Distribute payloads.
2677  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
// NOTE(review): execution can reach this point with numPayloads smaller than
// getNumResults() (the check above lets that through when
// fail_on_payload_too_small is false or the empty pass-through applies). With
// an overflow result set, the difference passed to reserve() is then
// negative; verify how that converts to the reserve() size argument.
2678  if (getOverflowResult())
2679  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2680 
// Materialize the payloads as a vector of MappedValue, whatever their kind.
2681  auto container = [&]() {
2682  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2683  return llvm::map_to_vector(
2684  state.getPayloadOps(getHandle()),
2685  [](Operation *op) -> MappedValue { return op; });
2686  }
2687  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2688  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2689  [](Value v) -> MappedValue { return v; });
2690  }
2691  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2692  "unsupported kind of transform dialect type");
2693  return llvm::map_to_vector(state.getParams(getHandle()),
2694  [](Attribute a) -> MappedValue { return a; });
2695  }();
2696 
// Payload #i goes to result #i; indices past the last result are redirected
// to the overflow result. The first check above guarantees an overflow result
// exists whenever there are more payloads than results.
2697  for (auto &&en : llvm::enumerate(container)) {
2698  int64_t resultNum = en.index();
2699  if (resultNum >= getNumResults())
2700  resultNum = *getOverflowResult();
2701  resultHandles[resultNum].push_back(en.value());
2702  }
2703 
2704  // Set transform op results.
2705  for (auto &&it : llvm::enumerate(resultHandles))
2706  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2707  it.value());
2708 
2710 }
2711 
2712 void transform::SplitHandleOp::getEffects(
2713  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2714  onlyReadsHandle(getHandleMutable(), effects);
2715  producesHandle(getOperation()->getOpResults(), effects);
2716  // There are no effects on the Payload IR as this is only a handle
2717  // manipulation.
2718 }
2719 
2720 LogicalResult transform::SplitHandleOp::verify() {
2721  if (getOverflowResult().has_value() &&
2722  !(*getOverflowResult() < getNumResults()))
2723  return emitOpError("overflow_result is not a valid result index");
2724 
2725  for (Type resultType : getResultTypes()) {
2726  if (implementSameTransformInterface(getHandle().getType(), resultType))
2727  continue;
2728 
2729  return emitOpError("expects result types to implement the same transform "
2730  "interface as the operand type");
2731  }
2732 
2733  return success();
2734 }
2735 
2736 //===----------------------------------------------------------------------===//
2737 // ReplicateOp
2738 //===----------------------------------------------------------------------===//
2739 
// For every handle operand, produces a result whose payload is the operand's
// payload list repeated once per payload op associated with the pattern
// handle.
// NOTE(review): this listing skips numbered lines 2740 (return type) and 2767
// (presumably the final return statement); confirm against the original file.
2741 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2742  transform::TransformResults &results,
2743  transform::TransformState &state) {
// The repetition count is the number of payload ops of the pattern handle.
2744  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2745  for (const auto &en : llvm::enumerate(getHandles())) {
2746  Value handle = en.value();
// Operation handles: concatenate `numRepetitions` copies of the op list.
2747  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2748  SmallVector<Operation *> current =
2749  llvm::to_vector(state.getPayloadOps(handle));
2750  SmallVector<Operation *> payload;
2751  payload.reserve(numRepetitions * current.size());
2752  for (unsigned i = 0; i < numRepetitions; ++i)
2753  llvm::append_range(payload, current);
2754  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2755  } else {
// Anything else is asserted to be a parameter; its attribute list is
// repeated the same way.
2756  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2757  "expected param type");
2758  ArrayRef<Attribute> current = state.getParams(handle);
2759  SmallVector<Attribute> params;
2760  params.reserve(numRepetitions * current.size());
2761  for (unsigned i = 0; i < numRepetitions; ++i)
2762  llvm::append_range(params, current);
2763  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2764  params);
2765  }
2766  }
2768 }
2769 
2770 void transform::ReplicateOp::getEffects(
2771  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2772  onlyReadsHandle(getPatternMutable(), effects);
2773  onlyReadsHandle(getHandlesMutable(), effects);
2774  producesHandle(getOperation()->getOpResults(), effects);
2775 }
2776 
2777 //===----------------------------------------------------------------------===//
2778 // SequenceOp
2779 //===----------------------------------------------------------------------===//
2780 
// Applies the sequence: opens a region scope for the body, maps the block
// arguments, then applies the body block with the op's failure propagation
// mode.
// NOTE(review): this listing skips numbered lines 2781 (return type) and 2788
// (the statement executed when mapBlockArguments fails); confirm against the
// original file.
2782 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2783  transform::TransformResults &results,
2784  transform::TransformState &state) {
2785  // Map the entry block argument to the list of operations.
2786  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2787  if (failed(mapBlockArguments(state)))
2789 
2790  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2791  results);
2792 }
2793 
2794 static ParseResult parseSequenceOpOperands(
2795  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2796  Type &rootType,
2797  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2798  SmallVectorImpl<Type> &extraBindingTypes) {
2799  OpAsmParser::UnresolvedOperand rootOperand;
2800  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2801  if (!hasRoot.has_value()) {
2802  root = std::nullopt;
2803  return success();
2804  }
2805  if (failed(hasRoot.value()))
2806  return failure();
2807  root = rootOperand;
2808 
2809  if (succeeded(parser.parseOptionalComma())) {
2810  if (failed(parser.parseOperandList(extraBindings)))
2811  return failure();
2812  }
2813  if (failed(parser.parseColon()))
2814  return failure();
2815 
2816  // The paren is truly optional.
2817  (void)parser.parseOptionalLParen();
2818 
2819  if (failed(parser.parseType(rootType))) {
2820  return failure();
2821  }
2822 
2823  if (!extraBindings.empty()) {
2824  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2825  return failure();
2826  }
2827 
2828  if (extraBindingTypes.size() != extraBindings.size()) {
2829  return parser.emitError(parser.getNameLoc(),
2830  "expected types to be provided for all operands");
2831  }
2832 
2833  // The paren is truly optional.
2834  (void)parser.parseOptionalRParen();
2835  return success();
2836 }
2837 
// Prints the operand clause of a sequence op: the root, the optional extra
// bindings, then " : " and the types. The type list is wrapped in parentheses
// only when extra bindings are present. Nothing is printed without a root.
// NOTE(review): this listing skips numbered line 2838, the first line of the
// signature; the declaration elsewhere in this file gives it as
// `static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,`.
2839  Value root, Type rootType,
2840  ValueRange extraBindings,
2841  TypeRange extraBindingTypes) {
2842  if (!root)
2843  return;
2844 
2845  printer << root;
2846  bool hasExtras = !extraBindings.empty();
2847  if (hasExtras) {
2848  printer << ", ";
2849  printer.printOperands(extraBindings);
2850  }
2851 
2852  printer << " : ";
2853  if (hasExtras)
2854  printer << "(";
2855 
2856  printer << rootType;
2857  if (hasExtras)
2858  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2859 }
2860 
2861 /// Returns `true` if the given op operand may be consuming the handle value in
2862 /// the Transform IR. That is, if it may have a Free effect on it.
// Owners that do not implement TransformOpInterface are treated as consumers;
// for the others the answer comes from isHandleConsumed.
// NOTE(review): this listing skips numbered line 2863, the signature line; it
// is presumably `static bool isValueUsePotentialConsumer(OpOperand &use) {`.
// Confirm against the original file.
2864  // Conservatively assume the effect being present in absence of the interface.
2865  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2866  if (!iface)
2867  return true;
2868 
2869  return isHandleConsumed(use.get(), iface);
2870 }
2871 
// Fails if `value` has more than one use that may consume it. The diagnostic
// is obtained from `reportError` and gets one note for the first potential
// consumer and one for the second; the scan stops at the second.
// NOTE(review): this listing skips numbered line 2873, the first line of the
// declarator (presumably `checkDoubleConsume(Value value,`); confirm against
// the original file.
2872 LogicalResult
2874  function_ref<InFlightDiagnostic()> reportError) {
// First potentially consuming use seen so far, if any.
2875  OpOperand *potentialConsumer = nullptr;
2876  for (OpOperand &use : value.getUses()) {
2877  if (!isValueUsePotentialConsumer(use))
2878  continue;
2879 
2880  if (!potentialConsumer) {
2881  potentialConsumer = &use;
2882  continue;
2883  }
2884 
// A second potential consumer was found: report both uses.
2885  InFlightDiagnostic diag = reportError()
2886  << " has more than one potential consumer";
2887  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2888  << "used here as operand #" << potentialConsumer->getOperandNumber();
2889  diag.attachNote(use.getOwner()->getLoc())
2890  << "used here as operand #" << use.getOperandNumber();
2891  return diag;
2892  }
2893 
2894  return success();
2895 }
2896 
// Verifies a sequence op: no extra operands without a root, no block argument
// or nested result with more than one potential consumer, every non-last
// child implements TransformOpInterface, and the terminator operand types
// match the op result types.
// NOTE(review): this listing skips numbered line 2920 (presumably the
// declaration of `diag`); confirm against the original file.
// NOTE(review): the assertion message says "more than 1" while the asserted
// condition is `>= 1`.
2897 LogicalResult transform::SequenceOp::verify() {
2898  assert(getBodyBlock()->getNumArguments() >= 1 &&
2899  "the number of arguments must have been verified to be more than 1 by "
2900  "PossibleTopLevelTransformOpTrait");
2901 
2902  if (!getRoot() && !getExtraBindings().empty()) {
2903  return emitOpError()
2904  << "does not expect extra operands when used as top-level";
2905  }
2906 
2907  // Check if a block argument has more than one consuming use.
2908  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2909  if (failed(checkDoubleConsume(arg, [this, arg]() {
2910  return (emitOpError() << "block argument #" << arg.getArgNumber());
2911  }))) {
2912  return failure();
2913  }
2914  }
2915 
2916  // Check properties of the nested operations they cannot check themselves.
2917  for (Operation &child : *getBodyBlock()) {
// The last op of the block is exempt from the interface requirement.
2918  if (!isa<TransformOpInterface>(child) &&
2919  &child != &getBodyBlock()->back()) {
2921  emitOpError()
2922  << "expected children ops to implement TransformOpInterface";
2923  diag.attachNote(child.getLoc()) << "op without interface";
2924  return diag;
2925  }
2926 
// Results of nested ops are checked for double consumption as well; here the
// error is emitted on the child op.
2927  for (OpResult result : child.getResults()) {
2928  auto report = [&]() {
2929  return (child.emitError() << "result #" << result.getResultNumber());
2930  };
2931  if (failed(checkDoubleConsume(result, report)))
2932  return failure();
2933  }
2934  }
2935 
2936  if (!getBodyBlock()->mightHaveTerminator())
2937  return emitOpError() << "expects to have a terminator in the body";
2938 
2939  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2940  getOperation()->getResultTypes()) {
2941  InFlightDiagnostic diag = emitOpError()
2942  << "expects the types of the terminator operands "
2943  "to match the types of the result";
2944  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2945  return diag;
2946  }
2947  return success();
2948 }
2949 
2950 void transform::SequenceOp::getEffects(
2951  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2952  getPotentialTopLevelEffects(effects);
2953 }
2954 
// Returns the operands forwarded to the body region on entry: all operands of
// the op, or an empty range positioned at the end of the operand list when the
// op has no operands.
// NOTE(review): this listing skips numbered line 2955, the return type; the
// return statements suggest `OperandRange`. Confirm against the original file.
2956 transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2957  assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2958  if (getOperation()->getNumOperands() > 0)
2959  return getOperation()->getOperands();
2960  return OperandRange(getOperation()->operand_end(),
2961  getOperation()->operand_end());
2962 }
2963 
// Lists the possible successors of `point`: from the parent op, control enters
// the body region; from the body terminator, control returns to the op with
// its results.
// NOTE(review): this listing skips numbered line 2970, the else-branch of the
// conditional expression that closes the emplace_back call; confirm against
// the original file.
2964 void transform::SequenceOp::getSuccessorRegions(
2965  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2966  if (point.isParent()) {
2967  Region *bodyRegion = &getBody();
// The region arguments are passed as successor inputs only when the op has
// operands.
2968  regions.emplace_back(bodyRegion, getNumOperands() != 0
2969  ? bodyRegion->getArguments()
2971  return;
2972  }
2973 
2974  assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2975  &getBody() &&
2976  "unexpected region index");
2977  regions.emplace_back(getOperation(), getOperation()->getResults());
2978 }
2979 
2980 void transform::SequenceOp::getRegionInvocationBounds(
2981  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2982  (void)operands;
2983  bounds.emplace_back(1, 1);
2984 }
2985 
2986 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2987  TypeRange resultTypes,
2988  FailurePropagationMode failurePropagationMode,
2989  Value root,
2990  SequenceBodyBuilderFn bodyBuilder) {
2991  build(builder, state, resultTypes, failurePropagationMode, root,
2992  /*extra_bindings=*/ValueRange());
2993  Type bbArgType = root.getType();
2994  buildSequenceBody(builder, state, bbArgType,
2995  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2996 }
2997 
2998 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2999  TypeRange resultTypes,
3000  FailurePropagationMode failurePropagationMode,
3001  Value root, ValueRange extraBindings,
3002  SequenceBodyBuilderArgsFn bodyBuilder) {
3003  build(builder, state, resultTypes, failurePropagationMode, root,
3004  extraBindings);
3005  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3006  bodyBuilder);
3007 }
3008 
3009 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3010  TypeRange resultTypes,
3011  FailurePropagationMode failurePropagationMode,
3012  Type bbArgType,
3013  SequenceBodyBuilderFn bodyBuilder) {
3014  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3015  /*extra_bindings=*/ValueRange());
3016  buildSequenceBody(builder, state, bbArgType,
3017  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3018 }
3019 
3020 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3021  TypeRange resultTypes,
3022  FailurePropagationMode failurePropagationMode,
3023  Type bbArgType, TypeRange extraBindingTypes,
3024  SequenceBodyBuilderArgsFn bodyBuilder) {
3025  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3026  /*extra_bindings=*/ValueRange());
3027  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3028 }
3029 
3030 //===----------------------------------------------------------------------===//
3031 // PrintOp
3032 //===----------------------------------------------------------------------===//
3033 
3034 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3035  StringRef name) {
3036  if (!name.empty())
3037  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3038 }
3039 
3040 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3041  Value target, StringRef name) {
3042  result.addOperands({target});
3043  build(builder, result, name);
3044 }
3045 
// Prints IR to llvm::outs() after a "[[[ IR printer: <name> " banner: the
// top-level op when the op has no target, otherwise every payload op of the
// target handle, each followed by a newline.
// NOTE(review): this listing skips numbered lines 3046 (return type), 3067
// and 3077 (presumably the return statements of the two paths); confirm
// against the original file.
3047 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3048  transform::TransformResults &results,
3049  transform::TransformState &state) {
3050  llvm::outs() << "[[[ IR printer: ";
3051  if (getName().has_value())
3052  llvm::outs() << *getName() << " ";
3053 
// Translate the optional op flags into printing flags; an absent flag counts
// as false.
3054  OpPrintingFlags printFlags;
3055  if (getAssumeVerified().value_or(false))
3056  printFlags.assumeVerified();
3057  if (getUseLocalScope().value_or(false))
3058  printFlags.useLocalScope();
3059  if (getSkipRegions().value_or(false))
3060  printFlags.skipRegions();
3061 
// Without a target, print the top-level op.
3062  if (!getTarget()) {
3063  llvm::outs() << "top-level ]]]\n";
3064  state.getTopLevel()->print(llvm::outs(), printFlags);
3065  llvm::outs() << "\n";
3066  llvm::outs().flush();
3068  }
3069 
3070  llvm::outs() << "]]]\n";
3071  for (Operation *target : state.getPayloadOps(getTarget())) {
3072  target->print(llvm::outs(), printFlags);
3073  llvm::outs() << "\n";
3074  }
3075 
3076  llvm::outs().flush();
3078 }
3079 
3080 void transform::PrintOp::getEffects(
3081  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3082  // We don't really care about mutability here, but `getTarget` now
3083  // unconditionally casts to a specific type before verification could run
3084  // here.
3085  if (!getTargetMutable().empty())
3086  onlyReadsHandle(getTargetMutable()[0], effects);
3087  onlyReadsPayload(effects);
3088 
3089  // There is no resource for stderr file descriptor, so just declare print
3090  // writes into the default resource.
3091  effects.emplace_back(MemoryEffects::Write::get());
3092 }
3093 
3094 //===----------------------------------------------------------------------===//
3095 // VerifyOp
3096 //===----------------------------------------------------------------------===//
3097 
// Runs the MLIR verifier on one payload op; on failure, the returned
// diagnostic carries a note pointing at that op.
// NOTE(review): this listing skips numbered lines 3098 (return type), 3101 (a
// parameter declaration), 3104 (presumably the declaration of `diag`) and
// 3109 (presumably the success return); confirm against the original file.
3099 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3100  Operation *target,
3102  transform::TransformState &state) {
3103  if (failed(::mlir::verify(target))) {
3105  << "failed to verify payload op";
3106  diag.attachNote(target->getLoc()) << "payload op";
3107  return diag;
3108  }
3110 }
3111 
3112 void transform::VerifyOp::getEffects(
3113  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3114  transform::onlyReadsHandle(getTargetMutable(), effects);
3115 }
3116 
3117 //===----------------------------------------------------------------------===//
3118 // YieldOp
3119 //===----------------------------------------------------------------------===//
3120 
3121 void transform::YieldOp::getEffects(
3122  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3123  onlyReadsHandle(getOperandsMutable(), effects);
3124 }
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
Definition: DLTI.cpp:38
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:296
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:288
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:282
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
Region * getParentRegion()
Returns the region to which the operation belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:537
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:232
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
Definition: PassOptions.h:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.