MLIR  18.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dominance.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/Verifier.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Pass/PassManager.h"
27 #include "mlir/Pass/PassRegistry.h"
28 #include "mlir/Transforms/CSE.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/SmallPtrSet.h"
35 #include "llvm/Support/Debug.h"
36 #include <optional>
37 
// Debug-output helpers. DBGS() yields llvm::dbgs() after printing the
// "[transform-dialect] " tag; DBGS_MATCHER() does the same with the
// "[transform-matcher] " tag. DEBUG_MATCHER(x) forwards `x` to DEBUG_WITH_TYPE
// under the matcher debug type.
38 #define DEBUG_TYPE "transform-dialect"
39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
40 
41 #define DEBUG_TYPE_MATCHER "transform-matcher"
42 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
43 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
44 
45 using namespace mlir;
46 
// Forward declarations of file-local parser/printer helpers for sequence-op
// operands and for ForeachMatch matcher/action symbol lists.
// NOTE(review): the opening lines of the two parse-helper declarations are not
// visible in this listing; only their parameter lists remain.
48  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
49  Type &rootType,
50  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
51  SmallVectorImpl<Type> &extraBindingTypes);
52 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
53  Value root, Type rootType,
54  ValueRange extraBindings,
55  TypeRange extraBindingTypes);
56 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
57  ArrayAttr matchers, ArrayAttr actions);
59  ArrayAttr &matchers,
60  ArrayAttr &actions);
61 
62 /// Helper function to check if the given transform op is contained in (or
63 /// equal to) the given payload target op. In that case, an error is returned.
64 /// Transforming transform IR that is currently executing is generally unsafe.
///
/// The check walks from the transform op upwards through its parent ops. When
/// `payload` is found on that chain, a definite failure is emitted on the
/// transform op with a note attached at the payload op's location.
/// NOTE(review): the return-type line, the declaration of `diag`, and the
/// final return statement are not visible in this listing.
66 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
67  Operation *payload) {
// Start at the transform op itself so that "transform == payload" is caught.
68  Operation *transformAncestor = transform.getOperation();
69  while (transformAncestor) {
70  if (transformAncestor == payload) {
72  transform.emitDefiniteFailure()
73  << "cannot apply transform to itself (or one of its ancestors)";
74  diag.attachNote(payload->getLoc()) << "target payload op";
75  return diag;
76  }
// Move one level up; the loop ends once getParentOp() yields null.
77  transformAncestor = transformAncestor->getParentOp();
78  }
80 }
81 
82 #define GET_OP_CLASSES
83 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
84 
85 //===----------------------------------------------------------------------===//
86 // AlternativesOp
87 //===----------------------------------------------------------------------===//
88 
// Returns the operands forwarded to a successor. All operands of the op are
// returned only when `point` is not the parent op and the op has exactly one
// operand; in every other case an empty range anchored at operand_end() is
// returned.
// NOTE(review): the return-type line is not visible in this listing.
90 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
91  if (!point.isParent() && getOperation()->getNumOperands() == 1)
92  return getOperation()->getOperands();
93  return OperandRange(getOperation()->operand_end(),
94  getOperation()->operand_end());
95 }
96 
// Appends to `regions` the possible successors of `point`. When `point` is the
// parent op, every alternative region is a successor. When `point` is one of
// the regions, only the alternatives that come after it are successors, and
// the op's results are added as a further successor.
// NOTE(review): the second branch of the conditional expression passed to
// emplace_back is not visible in this listing.
97 void transform::AlternativesOp::getSuccessorRegions(
98  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Skip zero regions when coming from the parent, otherwise skip everything up
// to and including the region we are branching from.
99  for (Region &alternative : llvm::drop_begin(
100  getAlternatives(),
101  point.isParent() ? 0
102  : point.getRegionOrNull()->getRegionNumber() + 1)) {
103  regions.emplace_back(&alternative, !getOperands().empty()
104  ? alternative.getArguments()
106  }
107  if (!point.isParent())
108  regions.emplace_back(getOperation()->getResults());
109 }
110 
111 void transform::AlternativesOp::getRegionInvocationBounds(
112  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
113  (void)operands;
114  // The region corresponding to the first alternative is always executed, the
115  // remaining may or may not be executed.
116  bounds.reserve(getNumRegions());
117  bounds.emplace_back(1, 1);
118  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
119 }
120 
// Sets every result of the op that owns `block` to an empty (`{}`) entry in
// `results`.
// NOTE(review): the first line of this function's signature is not visible in
// this listing.
122  transform::TransformResults &results) {
123  for (const auto &res : block->getParentOp()->getOpResults())
124  results.set(res, {});
125 }
126 
// Applies the alternatives in order, each on fresh clones of the scope ops.
//
// The scope ops are the payload of the `scope` handle when one is given and
// the top-level op otherwise. A definite failure is produced when a scope op
// contains this transform op or is not isolated from above. For each
// alternative region, the scope ops are cloned, the clones are mapped to the
// region's first block argument, and the region's transforms are applied. A
// silenceable failure abandons that alternative (its clones are erased) and
// the next one is tried. When an alternative runs through, the clones replace
// the originals. If no alternative succeeds, a silenceable error is returned.
//
// NOTE(review): several lines of this function (return type, one parameter,
// and some return statements) are not visible in this listing.
128 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
130  transform::TransformState &state) {
131  SmallVector<Operation *> originals;
132  if (Value scopeHandle = getScope())
133  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
134  else
135  originals.push_back(state.getTopLevel());
136 
// Validate every scope op before touching any IR.
137  for (Operation *original : originals) {
138  if (original->isAncestor(getOperation())) {
139  auto diag = emitDefiniteFailure()
140  << "scope must not contain the transforms being applied";
141  diag.attachNote(original->getLoc()) << "scope";
142  return diag;
143  }
144  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
145  auto diag = emitDefiniteFailure()
146  << "only isolated-from-above ops can be alternative scopes";
147  diag.attachNote(original->getLoc()) << "scope";
148  return diag;
149  }
150  }
151 
152  for (Region &reg : getAlternatives()) {
153  // Clone the scope operations and make the transforms in this alternative
154  // region apply to them by virtue of mapping the block argument (the only
155  // visible handle) to the cloned scope operations. This effectively prevents
156  // the transformation from accessing any IR outside the scope.
157  auto scope = state.make_region_scope(reg);
158  auto clones = llvm::to_vector(
159  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
// Scope-exit guard: the clones are erased on every exit from this iteration
// unless release() is called below.
160  auto deleteClones = llvm::make_scope_exit([&] {
161  for (Operation *clone : clones)
162  clone->erase();
163  });
164  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
166 
// NOTE: this local shadows the free function `failed`, hence the qualified
// `::mlir::failed` further down.
167  bool failed = false;
168  for (Operation &transform : reg.front().without_terminator()) {
170  state.applyTransform(cast<TransformOpInterface>(transform));
171  if (result.isSilenceableFailure()) {
172  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
173  << "\n");
174  failed = true;
175  break;
176  }
177 
178  if (::mlir::failed(result.silence()))
180  }
181 
182  // If all operations in the given alternative succeeded, no need to consider
183  // the rest. Replace the original scoping operation with the clone on which
184  // the transformations were performed.
185  if (!failed) {
186  // We will be using the clones, so cancel their scheduled deletion.
187  deleteClones.release();
// NOTE: this local `rewriter` shadows the function parameter of the same name.
188  TrackingListener listener(state, *this);
189  IRRewriter rewriter(getContext(), &listener);
190  for (const auto &kvp : llvm::zip(originals, clones)) {
191  Operation *original = std::get<0>(kvp);
192  Operation *clone = std::get<1>(kvp);
// Insert the (so far detached) clone right before the original, then replace
// the original's results with the clone's results.
193  original->getBlock()->getOperations().insert(original->getIterator(),
194  clone);
195  rewriter.replaceOp(original, clone->getResults());
196  }
197  detail::forwardTerminatorOperands(&reg.front(), state, results);
199  }
200  }
201  return emitSilenceableError() << "all alternatives failed";
202 }
203 
204 void transform::AlternativesOp::getEffects(
205  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
206  consumesHandle(getOperands(), effects);
207  producesHandle(getResults(), effects);
208  for (Region *region : getRegions()) {
209  if (!region->empty())
210  producesHandle(region->front().getArguments(), effects);
211  }
212  modifiesPayload(effects);
213 }
214 
// Checks, for every alternative region, that the operand types of the first
// block's terminator equal the result types of the op. On a mismatch an op
// error is emitted with a note at the terminator.
// NOTE(review): the signature line of this function is not visible in this
// listing.
216  for (Region &alternative : getAlternatives()) {
217  Block &block = alternative.front();
218  Operation *terminator = block.getTerminator();
219  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
220  InFlightDiagnostic diag = emitOpError()
221  << "expects terminator operands to have the "
222  "same type as results of the operation";
223  diag.attachNote(terminator->getLoc()) << "terminator";
224  return diag;
225  }
226  }
227 
228  return success();
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // AnnotateOp
233 //===----------------------------------------------------------------------===//
234 
// Sets the attribute named getName() on every payload op of the target
// handle.
//
// When a parameter handle is present and holds exactly one attribute, that
// attribute is used for all targets. When it holds a different number of
// attributes, the count must equal the number of targets (otherwise a
// silenceable error is returned) and attributes are assigned pairwise.
//
// NOTE(review): the declaration of `attr`, the return type, one parameter and
// the return statements are not visible in this listing.
236 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
238  transform::TransformState &state) {
239  SmallVector<Operation *> targets =
240  llvm::to_vector(state.getPayloadOps(getTarget()));
241 
243  if (auto paramH = getParam()) {
244  ArrayRef<Attribute> params = state.getParams(paramH);
245  if (params.size() != 1) {
// Pairwise mode: one parameter per target.
246  if (targets.size() != params.size()) {
247  return emitSilenceableError()
248  << "parameter and target have different payload lengths ("
249  << params.size() << " vs " << targets.size() << ")";
250  }
251  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
252  target->setAttr(getName(), attr);
254  }
// Single parameter: reuse it for every target below.
255  attr = params[0];
256  }
257  for (auto target : targets)
258  target->setAttr(getName(), attr);
260 }
261 
262 void transform::AnnotateOp::getEffects(
263  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
264  onlyReadsHandle(getTarget(), effects);
265  onlyReadsHandle(getParam(), effects);
266  modifiesPayload(effects);
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // ApplyCommonSubexpressionEliminationOp
271 //===----------------------------------------------------------------------===//
272 
// Runs common-subexpression elimination on `target` through the transform
// rewriter, using a freshly constructed DominanceInfo. A failed payload check
// is returned to the caller before any rewriting happens.
// NOTE(review): the return type, the initializer of `payloadCheck` and the
// final return statement are not visible in this listing.
274 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
275  transform::TransformRewriter &rewriter, Operation *target,
276  ApplyToEachResultList &results, transform::TransformState &state) {
277  // Make sure that this transform is not applied to itself. Modifying the
278  // transform IR while it is being interpreted is generally dangerous.
279  DiagnosedSilenceableFailure payloadCheck =
281  if (!payloadCheck.succeeded())
282  return payloadCheck;
283 
284  DominanceInfo domInfo;
285  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
287 }
288 
// Declares the effects of the op; the target handle is only read.
// NOTE(review): one further statement of this body is not visible in this
// listing.
289 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
290  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
291  transform::onlyReadsHandle(getTarget(), effects);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // ApplyDeadCodeEliminationOp
297 //===----------------------------------------------------------------------===//
298 
// Erases trivially dead ops nested under `target` (never `target` itself).
//
// A first post-order walk erases every op that is already trivially dead and
// records the defining ops of its operands (including operands of nested ops)
// in a worklist, restricted to ops properly nested under `target`. The
// worklist is then drained: ops that have become trivially dead are erased and
// their operand-defining ops are queued in turn.
//
// NOTE(review): the initializer of `payloadCheck` and the final return
// statement are not visible in this listing.
299 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
300  transform::TransformRewriter &rewriter, Operation *target,
301  ApplyToEachResultList &results, transform::TransformState &state) {
302  // Make sure that this transform is not applied to itself. Modifying the
303  // transform IR while it is being interpreted is generally dangerous.
304  DiagnosedSilenceableFailure payloadCheck =
306  if (!payloadCheck.succeeded())
307  return payloadCheck;
308 
309  // Maintain a worklist of potentially dead ops.
310  SetVector<Operation *> worklist;
311 
312  // Helper function that adds all defining ops of used values (operands and
313  // operands of nested ops).
314  auto addDefiningOpsToWorklist = [&](Operation *op) {
315  op->walk([&](Operation *op) {
316  for (Value v : op->getOperands())
317  if (Operation *defOp = v.getDefiningOp())
// Only ops strictly inside `target` are candidates for erasure.
318  if (target->isProperAncestor(defOp))
319  worklist.insert(defOp);
320  });
321  };
322 
323  // Helper function that erases an op.
324  auto eraseOp = [&](Operation *op) {
325  // Remove op and nested ops from the worklist.
// This keeps the worklist free of pointers to ops that are about to be erased.
326  op->walk([&](Operation *op) {
327  auto it = llvm::find(worklist, op);
328  if (it != worklist.end())
329  worklist.erase(it);
330  });
331  rewriter.eraseOp(op);
332  };
333 
334  // Initial walk over the IR.
335  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
336  if (op != target && isOpTriviallyDead(op)) {
337  addDefiningOpsToWorklist(op);
338  eraseOp(op);
339  }
340  });
341 
342  // Erase all ops that have become dead.
343  while (!worklist.empty()) {
344  Operation *op = worklist.pop_back_val();
345  if (!isOpTriviallyDead(op))
346  continue;
347  addDefiningOpsToWorklist(op);
348  eraseOp(op);
349  }
350 
352 }
353 
// Declares the effects of the op; the target handle is only read.
// NOTE(review): one further statement of this body is not visible in this
// listing.
354 void transform::ApplyDeadCodeEliminationOp::getEffects(
355  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
356  transform::onlyReadsHandle(getTarget(), effects);
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ApplyPatternsOp
362 //===----------------------------------------------------------------------===//
363 
// Applies the rewrite patterns described by the ops in this op's region to
// `target`.
//
// Patterns are collected from each child op through
// PatternDescriptorOpInterface and frozen. They are then applied greedily:
// on `target` itself when it is isolated from above, otherwise on the list of
// ops nested under `target`. When getApplyCse() is set, CSE runs after each
// pattern application and the loop repeats while CSE reports a change, up to
// kNumMaxIterations rounds. A failed greedy application yields a silenceable
// failure; hitting the iteration limit yields a definite failure.
//
// NOTE(review): the initializer of `payloadCheck`, the declaration of `ops`
// and the final return statement are not visible in this listing.
364 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
365  transform::TransformRewriter &rewriter, Operation *target,
366  ApplyToEachResultList &results, transform::TransformState &state) {
367  // Make sure that this transform is not applied to itself. Modifying the
368  // transform IR while it is being interpreted is generally dangerous. Even
369  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
370  // performs many additional simplifications such as dead code elimination.
371  DiagnosedSilenceableFailure payloadCheck =
373  if (!payloadCheck.succeeded())
374  return payloadCheck;
375 
376  // Gather all specified patterns.
377  MLIRContext *ctx = target->getContext();
378  RewritePatternSet patterns(ctx);
379  if (!getRegion().empty()) {
380  for (Operation &op : getRegion().front()) {
381  cast<transform::PatternDescriptorOpInterface>(&op)
382  .populatePatternsWithState(patterns, state);
383  }
384  }
385 
386  // Configure the GreedyPatternRewriteDriver.
// The driver is given the transform rewriter's listener.
387  GreedyRewriteConfig config;
388  config.listener =
389  static_cast<RewriterBase::Listener *>(rewriter.getListener());
390  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
391 
392  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
393  // was requested, apply the greedy pattern rewrite only once. (The greedy
394  // pattern rewrite driver already iterates to a fixpoint internally.)
395  bool cseChanged = false;
396  // One or two iterations should be sufficient. Stop iterating after a certain
397  // threshold to make debugging easier.
398  static const int64_t kNumMaxIterations = 50;
399  int64_t iteration = 0;
400  do {
401  LogicalResult result = failure();
402  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
403  // Op is isolated from above. Apply patterns and also perform region
404  // simplification.
405  result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
406  } else {
407  // Manually gather list of ops because the other
408  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
409  // from above. This way, patterns can be applied to ops that are not
410  // isolated from above. Regions are not being simplified. Furthermore,
411  // only a single greedy rewrite iteration is performed.
413  target->walk([&](Operation *nestedOp) {
414  if (target != nestedOp)
415  ops.push_back(nestedOp);
416  });
417  result = applyOpPatternsAndFold(ops, frozenPatterns, config);
418  }
419 
420  // A failure typically indicates that the pattern application did not
421  // converge.
422  if (failed(result)) {
423  return emitSilenceableFailure(target)
424  << "greedy pattern application failed";
425  }
426 
427  if (getApplyCse()) {
428  DominanceInfo domInfo;
429  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
430  &cseChanged);
431  }
// `iteration` is only incremented when CSE changed something, so without CSE
// the loop body runs exactly once.
432  } while (cseChanged && ++iteration < kNumMaxIterations);
433 
434  if (iteration == kNumMaxIterations)
435  return emitDefiniteFailure() << "fixpoint iteration did not converge";
436 
438 }
439 
// Checks that every op in the first block of the region implements
// PatternDescriptorOpInterface; an empty region is accepted. On a violation an
// op error is emitted with a note at the offending child op.
// NOTE(review): the signature line of this function is not visible in this
// listing.
441  if (!getRegion().empty()) {
442  for (Operation &op : getRegion().front()) {
443  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
444  InFlightDiagnostic diag = emitOpError()
445  << "expected children ops to implement "
446  "PatternDescriptorOpInterface";
447  diag.attachNote(op.getLoc()) << "op without interface";
448  return diag;
449  }
450  }
451  }
452  return success();
453 }
454 
// Declares the effects of the op; the target handle is only read.
// NOTE(review): one further statement of this body is not visible in this
// listing.
455 void transform::ApplyPatternsOp::getEffects(
456  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
457  transform::onlyReadsHandle(getTarget(), effects);
459 }
460 
461 void transform::ApplyPatternsOp::build(
462  OpBuilder &builder, OperationState &result, Value target,
463  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
464  result.addOperands(target);
465 
466  OpBuilder::InsertionGuard g(builder);
467  Region *region = result.addRegion();
468  builder.createBlock(region);
469  if (bodyBuilder)
470  bodyBuilder(builder, result.location);
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // ApplyCanonicalizationPatternsOp
475 //===----------------------------------------------------------------------===//
476 
// Adds canonicalization patterns to `patterns`: first those provided by every
// dialect loaded in the pattern set's context, then op-level ones.
// NOTE(review): the loop header that introduces `op` is not visible in this
// listing.
477 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
478  RewritePatternSet &patterns) {
479  MLIRContext *ctx = patterns.getContext();
480  for (Dialect *dialect : ctx->getLoadedDialects())
481  dialect->getCanonicalizationPatterns(patterns);
483  op.getCanonicalizationPatterns(patterns, ctx);
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // ApplyConversionPatternsOp
488 //===----------------------------------------------------------------------===//
489 
// Runs a dialect conversion on every payload op of the target handle.
//
// Steps:
//   1. Build the default type converter when a builder op is present.
//   2. Fill the ConversionTarget from the legal/illegal op and dialect
//      attributes.
//   3. For every descriptor op in the patterns region, pick its own type
//      converter or fall back to the default one (definite failure when
//      neither exists), then let it add target rules and patterns.
//   4. Apply a partial or full conversion to each target; a failed conversion
//      yields a silenceable error with a note at the target.
//
// NOTE(review): the parameter list, the initializer of `payloadCheck` and the
// final return statement are not visible in this listing.
490 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
493  MLIRContext *ctx = getContext();
494 
495  // Instantiate the default type converter if a type converter builder is
496  // specified.
497  std::unique_ptr<TypeConverter> defaultTypeConverter;
498  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
499  getDefaultTypeConverter();
500  if (typeConverterBuilder)
501  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
502 
503  // Configure conversion target.
504  ConversionTarget conversionTarget(*getContext());
505  if (getLegalOps())
506  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
507  conversionTarget.addLegalOp(
508  OperationName(cast<StringAttr>(attr).getValue(), ctx));
509  if (getIllegalOps())
510  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
511  conversionTarget.addIllegalOp(
512  OperationName(cast<StringAttr>(attr).getValue(), ctx));
513  if (getLegalDialects())
514  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
515  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
516  if (getIllegalDialects())
517  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
518  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
519 
520  // Gather all specified patterns.
521  RewritePatternSet patterns(ctx);
522  // Need to keep the converters alive until after pattern application because
523  // the patterns take a reference to an object that would otherwise get out of
524  // scope.
525  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
526  if (!getPatterns().empty()) {
527  for (Operation &op : getPatterns().front()) {
528  auto descriptor =
529  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
530 
531  // Check if this pattern set specifies a type converter.
532  std::unique_ptr<TypeConverter> typeConverter =
533  descriptor.getTypeConverter();
534  TypeConverter *converter = nullptr;
535  if (typeConverter) {
// Ownership moves into keepAliveConverters; `converter` is a non-owning view.
536  keepAliveConverters.emplace_back(std::move(typeConverter));
537  converter = keepAliveConverters.back().get();
538  } else {
539  // No type converter specified: Use the default type converter.
540  if (!defaultTypeConverter) {
541  auto diag = emitDefiniteFailure()
542  << "pattern descriptor does not specify type "
543  "converter and apply_conversion_patterns op has "
544  "no default type converter";
545  diag.attachNote(op.getLoc()) << "pattern descriptor op";
546  return diag;
547  }
548  converter = defaultTypeConverter.get();
549  }
550 
551  // Add descriptor-specific updates to the conversion target, which may
552  // depend on the final type converter. In structural converters, the
553  // legality of types dictates the dynamic legality of an operation.
554  descriptor.populateConversionTargetRules(*converter, conversionTarget);
555 
556  descriptor.populatePatterns(*converter, patterns);
557  }
558  }
559 
560  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
561  for (Operation *target : state.getPayloadOps(getTarget())) {
562  // Make sure that this transform is not applied to itself. Modifying the
563  // transform IR while it is being interpreted is generally dangerous.
564  DiagnosedSilenceableFailure payloadCheck =
566  if (!payloadCheck.succeeded())
567  return payloadCheck;
568 
569  LogicalResult status = failure();
570  if (getPartialConversion()) {
571  status = applyPartialConversion(target, conversionTarget, frozenPatterns);
572  } else {
573  status = applyFullConversion(target, conversionTarget, frozenPatterns);
574  }
575 
576  if (failed(status)) {
577  auto diag = emitSilenceableError() << "dialect conversion failed";
578  diag.attachNote(target->getLoc()) << "target op";
579  return diag;
580  }
581  }
582 
584 }
585 
587  if (getNumRegions() != 1 && getNumRegions() != 2)
588  return emitOpError() << "expected 1 or 2 regions";
589  if (!getPatterns().empty()) {
590  for (Operation &op : getPatterns().front()) {
591  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
593  emitOpError() << "expected pattern children ops to implement "
594  "ConversionPatternDescriptorOpInterface";
595  diag.attachNote(op.getLoc()) << "op without interface";
596  return diag;
597  }
598  }
599  }
600  if (getNumRegions() == 2) {
601  Region &typeConverterRegion = getRegion(1);
602  if (!llvm::hasSingleElement(typeConverterRegion.front()))
603  return emitOpError()
604  << "expected exactly one op in default type converter region";
605  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
606  &typeConverterRegion.front().front());
607  if (!typeConverterOp) {
608  InFlightDiagnostic diag = emitOpError()
609  << "expected default converter child op to "
610  "implement TypeConverterBuilderOpInterface";
611  diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
612  return diag;
613  }
614  // Check default type converter type.
615  if (!getPatterns().empty()) {
616  for (Operation &op : getPatterns().front()) {
617  auto descriptor =
618  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
619  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
620  return failure();
621  }
622  }
623  }
624  return success();
625 }
626 
// Declares the effects of the op; the target handle is consumed.
// NOTE(review): one further statement of this body is not visible in this
// listing.
627 void transform::ApplyConversionPatternsOp::getEffects(
628  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
629  transform::consumesHandle(getTarget(), effects);
631 }
632 
633 void transform::ApplyConversionPatternsOp::build(
634  OpBuilder &builder, OperationState &result, Value target,
635  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
636  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
637  result.addOperands(target);
638 
639  {
640  OpBuilder::InsertionGuard g(builder);
641  Region *region1 = result.addRegion();
642  builder.createBlock(region1);
643  if (patternsBodyBuilder)
644  patternsBodyBuilder(builder, result.location);
645  }
646  {
647  OpBuilder::InsertionGuard g(builder);
648  Region *region2 = result.addRegion();
649  builder.createBlock(region2);
650  if (typeConverterBodyBuilder)
651  typeConverterBodyBuilder(builder, result.location);
652  }
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // ApplyToLLVMConversionPatternsOp
657 //===----------------------------------------------------------------------===//
658 
659 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
660  TypeConverter &typeConverter, RewritePatternSet &patterns) {
661  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
662  assert(dialect && "expected that dialect is loaded");
663  auto iface = cast<ConvertToLLVMPatternInterface>(dialect);
664  // ConversionTarget is currently ignored because the enclosing
665  // apply_conversion_patterns op sets up its own ConversionTarget.
666  ConversionTarget target(*getContext());
667  iface->populateConvertToLLVMConversionPatterns(
668  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
669 }
670 
671 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
672  transform::TypeConverterBuilderOpInterface builder) {
673  if (builder.getTypeConverterType() != "LLVMTypeConverter")
674  return emitOpError("expected LLVMTypeConverter");
675  return success();
676 }
677 
// Checks that the dialect named by getDialectName() is loaded in the context
// and implements ConvertToLLVMPatternInterface; emits an op error naming the
// dialect otherwise.
// NOTE(review): the signature line of this function is not visible in this
// listing.
679  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
680  if (!dialect)
681  return emitOpError("unknown dialect or dialect not loaded: ")
682  << getDialectName();
683  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
684  if (!iface)
685  return emitOpError(
686  "dialect does not implement ConvertToLLVMPatternInterface or "
687  "extension was not loaded: ")
688  << getDialectName();
689  return success();
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // ApplyLoopInvariantCodeMotionOp
694 //===----------------------------------------------------------------------===//
695 
// Runs loop-invariant code motion on the loop-like `target` by calling
// moveLoopInvariantCode directly, without going through `rewriter`.
// NOTE(review): the return type, one parameter and the return statement are
// not visible in this listing.
697 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
698  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
700  transform::TransformState &state) {
701  // Currently, LICM does not remove operations, so we don't need tracking.
702  // If this ever changes, add a LICM entry point that takes a rewriter.
703  moveLoopInvariantCode(target);
705 }
706 
// Declares the effects of the op; the target handle is only read.
// NOTE(review): one further statement of this body is not visible in this
// listing.
707 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
708  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
709  transform::onlyReadsHandle(getTarget(), effects);
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // ApplyRegisteredPassOp
715 //===----------------------------------------------------------------------===//
716 
// Runs the registered pass or pass pipeline named by getPassName() on
// `target` and, on success, appends `target` to `results`.
//
// The name is looked up first among pass pipelines and then among passes; an
// unknown name is a definite failure, as is a failure to add the entry to the
// pass manager with the options from getOptions(). A failing pipeline run is
// a silenceable error with a note at the target.
//
// NOTE(review): the initializer of `payloadCheck` and the final return
// statement are not visible in this listing.
717 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
718  transform::TransformRewriter &rewriter, Operation *target,
719  ApplyToEachResultList &results, transform::TransformState &state) {
720  // Make sure that this transform is not applied to itself. Modifying the
721  // transform IR while it is being interpreted is generally dangerous. Even
722  // more so when applying passes because they may perform a wide range of IR
723  // modifications.
724  DiagnosedSilenceableFailure payloadCheck =
726  if (!payloadCheck.succeeded())
727  return payloadCheck;
728 
729  // Get pass or pass pipeline from registry.
730  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
731  if (!info)
732  info = PassInfo::lookup(getPassName());
733  if (!info)
734  return emitDefiniteFailure()
735  << "unknown pass or pass pipeline: " << getPassName();
736 
737  // Create pass manager and run the pass or pass pipeline.
738  PassManager pm(getContext());
// The callback reports option/pipeline errors as errors on this op.
739  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
740  emitError(msg);
741  return failure();
742  }))) {
743  return emitDefiniteFailure()
744  << "failed to add pass or pass pipeline to pipeline: "
745  << getPassName();
746  }
747  if (failed(pm.run(target))) {
748  auto diag = emitSilenceableError() << "pass pipeline failed";
749  diag.attachNote(target->getLoc()) << "target op";
750  return diag;
751  }
752 
753  results.push_back(target);
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // CastOp
759 //===----------------------------------------------------------------------===//
760 
// Forwards `target` unchanged into `results`; no payload IR is touched.
// NOTE(review): the return type and the return statement are not visible in
// this listing.
762 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
763  Operation *target, ApplyToEachResultList &results,
764  transform::TransformState &state) {
765  results.push_back(target);
767 }
768 
769 void transform::CastOp::getEffects(
770  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
771  onlyReadsPayload(effects);
772  onlyReadsHandle(getInput(), effects);
773  producesHandle(getOutput(), effects);
774 }
775 
776 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
777  assert(inputs.size() == 1 && "expected one input");
778  assert(outputs.size() == 1 && "expected one output");
779  return llvm::all_of(
780  std::initializer_list<Type>{inputs.front(), outputs.front()},
781  [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // ForeachMatchOp
786 //===----------------------------------------------------------------------===//
787 
788 /// Applies matcher operations from the given `block` assigning `op` as the
789 /// payload of the block's first argument. Updates `state` accordingly. If any
790 /// of the matcher produces a silenceable failure, discards it (printing the
791 /// content to the debug output stream) and returns failure. If any of the
792 /// matchers produces a definite failure, reports it and returns failure. If all
793 /// matchers in the block succeed, populates `mappings` with the payload
794 /// entities associated with the block terminator operands.
///
/// NOTE(review): the visible body returns the first non-succeeding diagnostic
/// to the caller unchanged instead of discarding or reporting it here, which
/// does not match the description above; confirm which is intended. An op in
/// the block that does not implement MatchOpInterface causes a definite
/// failure. The first lines of the signature, the declaration of `diag` and
/// some return statements are not visible in this listing.
797  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
798  assert(block.getParent() && "cannot match using a detached block");
799  auto matchScope = state.make_region_scope(*block.getParent());
800  if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
802 
803  for (Operation &match : block.without_terminator()) {
804  if (!isa<transform::MatchOpInterface>(match)) {
805  return emitDefiniteFailure(match.getLoc())
806  << "expected operations in the match part to "
807  "implement MatchOpInterface";
808  }
810  state.applyTransform(cast<transform::TransformOpInterface>(match));
811  if (diag.succeeded())
812  continue;
813 
814  return diag;
815  }
816 
817  // Remember the values mapped to the terminator operands so we can
818  // forward them to the action.
819  ValueRange yieldedValues = block.getTerminator()->getOperands();
820  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
822 }
823 
// Walks the payload ops of the root handle and, for each visited op, runs the
// action of the first matcher that matches it.
//
// First, every matcher/action symbol pair is resolved to a function-like op;
// a pair that resolves to an external symbol is a definite failure. Then each
// root payload op is walked. For every visited op the pairs are tried in
// order: a silenceable matcher failure moves on to the next pair, a definite
// one interrupts the walk. On a match, the values yielded by the matcher are
// mapped to the action's block arguments and the action's transforms are
// applied; any failure there interrupts the walk. Finally the `updated`
// result is set to the payload of the root handle.
//
// NOTE(review): the return type, one parameter, the declarations of
// `matchActionPairs`, `mappings`, `diag` and `result`, and some return
// statements are not visible in this listing.
825 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
827  transform::TransformState &state) {
829  matchActionPairs;
830  matchActionPairs.reserve(getMatchers().size());
831  SymbolTableCollection symbolTable;
832  for (auto &&[matcher, action] :
833  llvm::zip_equal(getMatchers(), getActions())) {
834  auto matcherSymbol =
835  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
836  getOperation(), cast<SymbolRefAttr>(matcher));
837  auto actionSymbol =
838  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
839  getOperation(), cast<SymbolRefAttr>(action));
840  assert(matcherSymbol && actionSymbol &&
841  "unresolved symbols not caught by the verifier");
842 
843  if (matcherSymbol.isExternal())
844  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
845  if (actionSymbol.isExternal())
846  return emitDefiniteFailure() << "unresolved external symbol " << action;
847 
848  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
849  }
850 
851  for (Operation *root : state.getPayloadOps(getRoot())) {
852  WalkResult walkResult = root->walk([&](Operation *op) {
853  // If getRestrictRoot is not present, skip over the root op itself so we
854  // don't invalidate it.
855  if (!getRestrictRoot() && op == root)
856  return WalkResult::advance();
857 
858  DEBUG_MATCHER({
859  DBGS_MATCHER() << "matching ";
860  op->print(llvm::dbgs(),
861  OpPrintingFlags().assumeVerified().skipRegions());
862  llvm::dbgs() << " @" << op << "\n";
863  });
864 
865  // Try all the match/action pairs until the first successful match.
866  for (auto [matcher, action] : matchActionPairs) {
869  matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
870  if (diag.isDefiniteFailure())
871  return WalkResult::interrupt();
872  if (diag.isSilenceableFailure()) {
873  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
874  << " failed: " << diag.getMessage());
875  continue;
876  }
877 
// The matcher succeeded: bind its yielded values to the action's arguments.
878  auto scope = state.make_region_scope(action.getFunctionBody());
879  for (auto &&[arg, map] : llvm::zip_equal(
880  action.getFunctionBody().front().getArguments(), mappings)) {
881  if (failed(state.mapBlockArgument(arg, map)))
882  return WalkResult::interrupt();
883  }
884 
885  for (Operation &transform :
886  action.getFunctionBody().front().without_terminator()) {
888  state.applyTransform(cast<TransformOpInterface>(transform));
889  if (failed(result.checkAndReport()))
890  return WalkResult::interrupt();
891  }
// Only the first matching pair is applied to a given op.
892  break;
893  }
894  return WalkResult::advance();
895  });
896  if (walkResult.wasInterrupted())
898  }
899 
900  // The root operation should not have been affected, so we can just reassign
901  // the payload to the result. Note that we need to consume the root handle to
902  // make sure any handles to operations inside, that could have been affected
903  // by actions, are invalidated.
904  results.set(llvm::cast<OpResult>(getUpdated()),
905  state.getPayloadOps(getRoot()));
907 }
908 
909 void transform::ForeachMatchOp::getEffects(
910  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
911  // Bail if invalid.
912  if (getOperation()->getNumOperands() < 1 ||
913  getOperation()->getNumResults() < 1) {
914  return modifiesPayload(effects);
915  }
916 
917  consumesHandle(getRoot(), effects);
918  producesHandle(getUpdated(), effects);
919  modifiesPayload(effects);
920 }
921 
922 /// Parses the comma-separated list of symbol reference pairs of the format
923 /// `@matcher -> @action`.
925  ArrayAttr &matchers,
926  ArrayAttr &actions) {
927  StringAttr matcher;
928  StringAttr action;
929  SmallVector<Attribute> matcherList;
930  SmallVector<Attribute> actionList;
931  do {
932  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
933  parser.parseSymbolName(action)) {
934  return failure();
935  }
936  matcherList.push_back(SymbolRefAttr::get(matcher));
937  actionList.push_back(SymbolRefAttr::get(action));
938  } while (parser.parseOptionalComma().succeeded());
939 
940  matchers = parser.getBuilder().getArrayAttr(matcherList);
941  actions = parser.getBuilder().getArrayAttr(actionList);
942  return success();
943 }
944 
945 /// Prints the comma-separated list of symbol reference pairs of the format
946 /// `@matcher -> @action`.
948  ArrayAttr matchers, ArrayAttr actions) {
949  printer.increaseIndent();
950  printer.increaseIndent();
951  for (auto &&[matcher, action, idx] : llvm::zip_equal(
952  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
953  printer.printNewline();
954  printer << cast<SymbolRefAttr>(matcher) << " -> "
955  << cast<SymbolRefAttr>(action);
956  if (idx != matchers.size() - 1)
957  printer << ", ";
958  }
959  printer.decreaseIndent();
960  printer.decreaseIndent();
961 }
962 
964  if (getMatchers().size() != getActions().size())
965  return emitOpError() << "expected the same number of matchers and actions";
966  if (getMatchers().empty())
967  return emitOpError() << "expected at least one match/action pair";
968 
969  llvm::SmallPtrSet<Attribute, 8> matcherNames;
970  for (Attribute name : getMatchers()) {
971  if (matcherNames.insert(name).second)
972  continue;
973  emitWarning() << "matcher " << name
974  << " is used more than once, only the first match will apply";
975  }
976 
977  return success();
978 }
979 
980 /// Returns `true` if both types implement one of the interfaces provided as
981 /// template parameters.
982 template <typename... Tys>
983 static bool implementSameInterface(Type t1, Type t2) {
984  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
985 }
986 
987 /// Returns `true` if both types implement one of the transform dialect
988 /// interfaces.
990  return implementSameInterface<transform::TransformHandleTypeInterface,
991  transform::TransformParamTypeInterface,
992  transform::TransformValueHandleTypeInterface>(
993  t1, t2);
994 }
995 
996 /// Checks that the attributes of the function-like operation have correct
997 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
998 /// annotations being present even if they can be inferred from the body.
1000 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1001  bool alsoVerifyInternal = false) {
1002  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1003  llvm::SmallDenseSet<unsigned> consumedArguments;
1004  if (!op.isExternal()) {
1005  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1006  consumedArguments);
1007  }
1008  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1009  bool isConsumed =
1010  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1011  nullptr;
1012  bool isReadOnly =
1013  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1014  nullptr;
1015  if (isConsumed && isReadOnly) {
1016  return transformOp.emitSilenceableError()
1017  << "argument #" << i << " cannot be both readonly and consumed";
1018  }
1019  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1020  return transformOp.emitSilenceableError()
1021  << "must provide consumed/readonly status for arguments of "
1022  "external or called ops";
1023  }
1024  if (op.isExternal())
1025  continue;
1026 
1027  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1028  return transformOp.emitSilenceableError()
1029  << "argument #" << i
1030  << " is consumed in the body but is not marked as such";
1031  }
1032  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1033  // Cannot use op.emitWarning() here as it would attempt to verify the op
1034  // before printing, resulting in infinite recursion.
1035  emitWarning(op->getLoc())
1036  << "op argument #" << i
1037  << " is not consumed in the body but is marked as consumed";
1038  }
1039  }
1041 }
1042 
1043 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1044  SymbolTableCollection &symbolTable) {
1045  assert(getMatchers().size() == getActions().size());
1046  auto consumedAttr =
1047  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1048  for (auto &&[matcher, action] :
1049  llvm::zip_equal(getMatchers(), getActions())) {
1050  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1051  symbolTable.lookupNearestSymbolFrom(getOperation(),
1052  cast<SymbolRefAttr>(matcher)));
1053  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1054  symbolTable.lookupNearestSymbolFrom(getOperation(),
1055  cast<SymbolRefAttr>(action)));
1056  if (!matcherSymbol ||
1057  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1058  return emitError() << "unresolved matcher symbol " << matcher;
1059  if (!actionSymbol ||
1060  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1061  return emitError() << "unresolved action symbol " << action;
1062 
1063  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1064  /*emitWarnings=*/false,
1065  /*alsoVerifyInternal=*/true)
1066  .checkAndReport())) {
1067  return failure();
1068  }
1070  /*emitWarnings=*/false,
1071  /*alsoVerifyInternal=*/true)
1072  .checkAndReport())) {
1073  return failure();
1074  }
1075 
1076  ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1077  ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1078  if (matcherResults.size() != actionArguments.size()) {
1079  return emitError() << "mismatching number of matcher results and "
1080  "action arguments between "
1081  << matcher << " (" << matcherResults.size() << ") and "
1082  << action << " (" << actionArguments.size() << ")";
1083  }
1084  for (auto &&[i, matcherType, actionType] :
1085  llvm::enumerate(matcherResults, actionArguments)) {
1086  if (implementSameTransformInterface(matcherType, actionType))
1087  continue;
1088 
1089  return emitError() << "mismatching type interfaces for matcher result "
1090  "and action argument #"
1091  << i;
1092  }
1093 
1094  if (!actionSymbol.getResultTypes().empty()) {
1096  emitError() << "action symbol is not expected to have results";
1097  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1098  return diag;
1099  }
1100 
1101  if (matcherSymbol.getArgumentTypes().size() != 1 ||
1102  !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1103  getRoot().getType())) {
1105  emitOpError() << "expects matcher symbol to have one argument with "
1106  "the same transform interface as the first operand";
1107  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1108  return diag;
1109  }
1110 
1111  if (matcherSymbol.getArgAttr(0, consumedAttr)) {
1113  emitOpError()
1114  << "does not expect matcher symbol to consume its operand";
1115  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1116  return diag;
1117  }
1118  }
1119  return success();
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // ForeachOp
1124 //===----------------------------------------------------------------------===//
1125 
1127 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1128  transform::TransformResults &results,
1129  transform::TransformState &state) {
1130  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1131  // Store payload ops in a vector because ops may be removed from the mapping
1132  // by the TrackingRewriter while the iteration is in progress.
1133  SmallVector<Operation *> targets =
1134  llvm::to_vector(state.getPayloadOps(getTarget()));
1135  for (Operation *op : targets) {
1136  auto scope = state.make_region_scope(getBody());
1137  if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
1139 
1140  // Execute loop body.
1141  for (Operation &transform : getBody().front().without_terminator()) {
1142  DiagnosedSilenceableFailure result = state.applyTransform(
1143  cast<transform::TransformOpInterface>(transform));
1144  if (!result.succeeded())
1145  return result;
1146  }
1147 
1148  // Append yielded payload ops to result list (if any).
1149  for (unsigned i = 0; i < getNumResults(); ++i) {
1150  auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1151  resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1152  }
1153  }
1154 
1155  for (unsigned i = 0; i < getNumResults(); ++i)
1156  results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1157 
1159 }
1160 
1161 void transform::ForeachOp::getEffects(
1162  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1163  BlockArgument iterVar = getIterationVariable();
1164  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1165  return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
1166  })) {
1167  consumesHandle(getTarget(), effects);
1168  } else {
1169  onlyReadsHandle(getTarget(), effects);
1170  }
1171 
1172  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1173  return doesModifyPayload(cast<TransformOpInterface>(&op));
1174  })) {
1175  modifiesPayload(effects);
1176  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1177  return doesReadPayload(cast<TransformOpInterface>(&op));
1178  })) {
1179  onlyReadsPayload(effects);
1180  }
1181 
1182  for (Value result : getResults())
1183  producesHandle(result, effects);
1184 }
1185 
1186 void transform::ForeachOp::getSuccessorRegions(
1187  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1188  Region *bodyRegion = &getBody();
1189  if (point.isParent()) {
1190  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1191  return;
1192  }
1193 
1194  // Branch back to the region or the parent.
1195  assert(point == getBody() && "unexpected region index");
1196  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1197  regions.emplace_back();
1198 }
1199 
1201 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1202  // The iteration variable op handle is mapped to a subset (one op to be
1203  // precise) of the payload ops of the ForeachOp operand.
1204  assert(point == getBody() && "unexpected region index");
1205  return getOperation()->getOperands();
1206 }
1207 
1208 transform::YieldOp transform::ForeachOp::getYieldOp() {
1209  return cast<transform::YieldOp>(getBody().front().getTerminator());
1210 }
1211 
1213  auto yieldOp = getYieldOp();
1214  if (getNumResults() != yieldOp.getNumOperands())
1215  return emitOpError() << "expects the same number of results as the "
1216  "terminator has operands";
1217  for (Value v : yieldOp.getOperands())
1218  if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1219  return yieldOp->emitOpError("expects operands to have types implementing "
1220  "TransformHandleTypeInterface");
1221  return success();
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // GetParentOp
1226 //===----------------------------------------------------------------------===//
1227 
1229 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1230  transform::TransformResults &results,
1231  transform::TransformState &state) {
1232  SmallVector<Operation *> parents;
1233  DenseSet<Operation *> resultSet;
1234  for (Operation *target : state.getPayloadOps(getTarget())) {
1235  Operation *parent = target;
1236  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1237  parent = parent->getParentOp();
1238  while (parent) {
1239  bool checkIsolatedFromAbove =
1240  !getIsolatedFromAbove() ||
1242  bool checkOpName = !getOpName().has_value() ||
1243  parent->getName().getStringRef() == *getOpName();
1244  if (checkIsolatedFromAbove && checkOpName)
1245  break;
1246  parent = parent->getParentOp();
1247  }
1248  if (!parent) {
1249  if (getAllowEmptyResults()) {
1250  results.set(llvm::cast<OpResult>(getResult()), parents);
1252  }
1254  emitSilenceableError()
1255  << "could not find a parent op that matches all requirements";
1256  diag.attachNote(target->getLoc()) << "target op";
1257  return diag;
1258  }
1259  }
1260  if (getDeduplicate()) {
1261  if (!resultSet.contains(parent)) {
1262  parents.push_back(parent);
1263  resultSet.insert(parent);
1264  }
1265  } else {
1266  parents.push_back(parent);
1267  }
1268  }
1269  results.set(llvm::cast<OpResult>(getResult()), parents);
1271 }
1272 
1273 //===----------------------------------------------------------------------===//
1274 // GetConsumersOfResult
1275 //===----------------------------------------------------------------------===//
1276 
1278 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1279  transform::TransformResults &results,
1280  transform::TransformState &state) {
1281  int64_t resultNumber = getResultNumber();
1282  auto payloadOps = state.getPayloadOps(getTarget());
1283  if (std::empty(payloadOps)) {
1284  results.set(cast<OpResult>(getResult()), {});
1286  }
1287  if (!llvm::hasSingleElement(payloadOps))
1288  return emitDefiniteFailure()
1289  << "handle must be mapped to exactly one payload op";
1290 
1291  Operation *target = *payloadOps.begin();
1292  if (target->getNumResults() <= resultNumber)
1293  return emitDefiniteFailure() << "result number overflow";
1294  results.set(llvm::cast<OpResult>(getResult()),
1295  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // GetDefiningOp
1301 //===----------------------------------------------------------------------===//
1302 
1304 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1305  transform::TransformResults &results,
1306  transform::TransformState &state) {
1307  SmallVector<Operation *> definingOps;
1308  for (Value v : state.getPayloadValues(getTarget())) {
1309  if (llvm::isa<BlockArgument>(v)) {
1311  emitSilenceableError() << "cannot get defining op of block argument";
1312  diag.attachNote(v.getLoc()) << "target value";
1313  return diag;
1314  }
1315  definingOps.push_back(v.getDefiningOp());
1316  }
1317  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // GetProducerOfOperand
1323 //===----------------------------------------------------------------------===//
1324 
1326 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1327  transform::TransformResults &results,
1328  transform::TransformState &state) {
1329  int64_t operandNumber = getOperandNumber();
1330  SmallVector<Operation *> producers;
1331  for (Operation *target : state.getPayloadOps(getTarget())) {
1332  Operation *producer =
1333  target->getNumOperands() <= operandNumber
1334  ? nullptr
1335  : target->getOperand(operandNumber).getDefiningOp();
1336  if (!producer) {
1338  emitSilenceableError()
1339  << "could not find a producer for operand number: " << operandNumber
1340  << " of " << *target;
1341  diag.attachNote(target->getLoc()) << "target op";
1342  return diag;
1343  }
1344  producers.push_back(producer);
1345  }
1346  results.set(llvm::cast<OpResult>(getResult()), producers);
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // GetResultOp
1352 //===----------------------------------------------------------------------===//
1353 
1355 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1356  transform::TransformResults &results,
1357  transform::TransformState &state) {
1358  int64_t resultNumber = getResultNumber();
1359  SmallVector<Value> opResults;
1360  for (Operation *target : state.getPayloadOps(getTarget())) {
1361  if (resultNumber >= target->getNumResults()) {
1363  emitSilenceableError() << "targeted op does not have enough results";
1364  diag.attachNote(target->getLoc()) << "target op";
1365  return diag;
1366  }
1367  opResults.push_back(target->getOpResult(resultNumber));
1368  }
1369  results.setValues(llvm::cast<OpResult>(getResult()), opResults);
1371 }
1372 
1373 //===----------------------------------------------------------------------===//
1374 // GetTypeOp
1375 //===----------------------------------------------------------------------===//
1376 
1377 void transform::GetTypeOp::getEffects(
1378  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1379  onlyReadsHandle(getValue(), effects);
1380  producesHandle(getResult(), effects);
1381  onlyReadsPayload(effects);
1382 }
1383 
1385 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1386  transform::TransformResults &results,
1387  transform::TransformState &state) {
1388  SmallVector<Attribute> params;
1389  for (Value value : state.getPayloadValues(getValue())) {
1390  Type type = value.getType();
1391  if (getElemental()) {
1392  if (auto shaped = dyn_cast<ShapedType>(type)) {
1393  type = shaped.getElementType();
1394  }
1395  }
1396  params.push_back(TypeAttr::get(type));
1397  }
1398  results.setParams(getResult().cast<OpResult>(), params);
1400 }
1401 
1402 //===----------------------------------------------------------------------===//
1403 // IncludeOp
1404 //===----------------------------------------------------------------------===//
1405 
1406 /// Applies the transform ops contained in `block`. Maps `results` to the same
1407 /// values as the operands of the block terminator.
1409 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1411  transform::TransformResults &results) {
1412  // Apply the sequenced ops one by one.
1413  for (Operation &transform : block.without_terminator()) {
1415  state.applyTransform(cast<transform::TransformOpInterface>(transform));
1416  if (result.isDefiniteFailure())
1417  return result;
1418 
1419  if (result.isSilenceableFailure()) {
1420  if (mode == transform::FailurePropagationMode::Propagate) {
1421  // Propagate empty results in case of early exit.
1422  forwardEmptyOperands(&block, state, results);
1423  return result;
1424  }
1425  (void)result.silence();
1426  }
1427  }
1428 
1429  // Forward the operation mapping for values yielded from the sequence to the
1430  // values produced by the sequence op.
1431  transform::detail::forwardTerminatorOperands(&block, state, results);
1433 }
1434 
1436 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1437  transform::TransformResults &results,
1438  transform::TransformState &state) {
1439  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1440  getOperation(), getTarget());
1441  assert(callee && "unverified reference to unknown symbol");
1442 
1443  if (callee.isExternal())
1444  return emitDefiniteFailure() << "unresolved external named sequence";
1445 
1446  // Map operands to block arguments.
1448  detail::prepareValueMappings(mappings, getOperands(), state);
1449  auto scope = state.make_region_scope(callee.getBody());
1450  for (auto &&[arg, map] :
1451  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1452  if (failed(state.mapBlockArgument(arg, map)))
1454  }
1455 
1457  callee.getBody().front(), getFailurePropagationMode(), state, results);
1458  mappings.clear();
1460  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1461  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1462  results.setMappedValues(result, mapping);
1463  return result;
1464 }
1465 
1467 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1468 
1469 void transform::IncludeOp::getEffects(
1470  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1471  // Always mark as modifying the payload.
1472  // TODO: a mechanism to annotate effects on payload. Even when all handles are
1473  // only read, the payload may still be modified, so we currently stay on the
1474  // conservative side and always indicate modification. This may prevent some
1475  // code reordering.
1476  modifiesPayload(effects);
1477 
1478  // Results are always produced.
1479  producesHandle(getResults(), effects);
1480 
1481  // Adds default effects to operands and results. This will be added if
1482  // preconditions fail so the trait verifier doesn't complain about missing
1483  // effects and the real precondition failure is reported later on.
1484  auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
1485 
1486  // Bail if the callee is unknown. This may run as part of the verification
1487  // process before we verified the validity of the callee or of this op.
1488  auto target =
1489  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1490  if (!target)
1491  return defaultEffects();
1492  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1493  getOperation(), getTarget());
1494  if (!callee)
1495  return defaultEffects();
1496  DiagnosedSilenceableFailure earlyVerifierResult =
1497  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1498  if (!earlyVerifierResult.succeeded()) {
1499  (void)earlyVerifierResult.silence();
1500  return defaultEffects();
1501  }
1502 
1503  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1504  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1505  consumesHandle(getOperand(i), effects);
1506  else
1507  onlyReadsHandle(getOperand(i), effects);
1508  }
1509 }
1510 
1512 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1513  // Access through indirection and do additional checking because this may be
1514  // running before the main op verifier.
1515  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1516  if (!targetAttr)
1517  return emitOpError() << "expects a 'target' symbol reference attribute";
1518 
1519  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1520  *this, targetAttr);
1521  if (!target)
1522  return emitOpError() << "does not reference a named transform sequence";
1523 
1524  FunctionType fnType = target.getFunctionType();
1525  if (fnType.getNumInputs() != getNumOperands())
1526  return emitError("incorrect number of operands for callee");
1527 
1528  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1529  if (getOperand(i).getType() != fnType.getInput(i)) {
1530  return emitOpError("operand type mismatch: expected operand type ")
1531  << fnType.getInput(i) << ", but provided "
1532  << getOperand(i).getType() << " for operand number " << i;
1533  }
1534  }
1535 
1536  if (fnType.getNumResults() != getNumResults())
1537  return emitError("incorrect number of results for callee");
1538 
1539  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1540  Type resultType = getResult(i).getType();
1541  Type funcType = fnType.getResult(i);
1542  if (!implementSameTransformInterface(resultType, funcType)) {
1543  return emitOpError() << "type of result #" << i
1544  << " must implement the same transform dialect "
1545  "interface as the corresponding callee result";
1546  }
1547  }
1548 
1550  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1551  /*alsoVerifyInternal=*/true)
1552  .checkAndReport();
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // MatchOperationEmptyOp
1557 //===----------------------------------------------------------------------===//
1558 
1559 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1560  ::std::optional<::mlir::Operation *> maybeCurrent,
1562  if (!maybeCurrent.has_value()) {
1563  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1565  }
1566  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1567  return emitSilenceableError() << "operation is not empty";
1568 }
1569 
1570 //===----------------------------------------------------------------------===//
1571 // MatchOperationNameOp
1572 //===----------------------------------------------------------------------===//
1573 
1574 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1575  Operation *current, transform::TransformResults &results,
1576  transform::TransformState &state) {
1577  StringRef currentOpName = current->getName().getStringRef();
1578  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1579  if (acceptedAttr.getValue() == currentOpName)
1581  }
1582  return emitSilenceableError() << "wrong operation name";
1583 }
1584 
1585 //===----------------------------------------------------------------------===//
1586 // MatchParamCmpIOp
1587 //===----------------------------------------------------------------------===//
1588 
1590 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1591  transform::TransformResults &results,
1592  transform::TransformState &state) {
1593  auto signedAPIntAsString = [&](APInt value) {
1594  std::string str;
1595  llvm::raw_string_ostream os(str);
1596  value.print(os, /*isSigned=*/true);
1597  return os.str();
1598  };
1599 
1600  ArrayRef<Attribute> params = state.getParams(getParam());
1601  ArrayRef<Attribute> references = state.getParams(getReference());
1602 
1603  if (params.size() != references.size()) {
1604  return emitSilenceableError()
1605  << "parameters have different payload lengths (" << params.size()
1606  << " vs " << references.size() << ")";
1607  }
1608 
1609  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1610  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1611  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1612  if (!intAttr || !refAttr) {
1613  return emitDefiniteFailure()
1614  << "non-integer parameter value not expected";
1615  }
1616  if (intAttr.getType() != refAttr.getType()) {
1617  return emitDefiniteFailure()
1618  << "mismatching integer attribute types in parameter #" << i;
1619  }
1620  APInt value = intAttr.getValue();
1621  APInt refValue = refAttr.getValue();
1622 
1623  // TODO: this copy will not be necessary in C++20.
1624  int64_t position = i;
1625  auto reportError = [&](StringRef direction) {
1627  emitSilenceableError() << "expected parameter to be " << direction
1628  << " " << signedAPIntAsString(refValue)
1629  << ", got " << signedAPIntAsString(value);
1630  diag.attachNote(getParam().getLoc())
1631  << "value # " << position
1632  << " associated with the parameter defined here";
1633  return diag;
1634  };
1635 
1636  switch (getPredicate()) {
1637  case MatchCmpIPredicate::eq:
1638  if (value.eq(refValue))
1639  break;
1640  return reportError("equal to");
1641  case MatchCmpIPredicate::ne:
1642  if (value.ne(refValue))
1643  break;
1644  return reportError("not equal to");
1645  case MatchCmpIPredicate::lt:
1646  if (value.slt(refValue))
1647  break;
1648  return reportError("less than");
1649  case MatchCmpIPredicate::le:
1650  if (value.sle(refValue))
1651  break;
1652  return reportError("less than or equal to");
1653  case MatchCmpIPredicate::gt:
1654  if (value.sgt(refValue))
1655  break;
1656  return reportError("greater than");
1657  case MatchCmpIPredicate::ge:
1658  if (value.sge(refValue))
1659  break;
1660  return reportError("greater than or equal to");
1661  }
1662  }
1664 }
1665 
1666 void transform::MatchParamCmpIOp::getEffects(
1667  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1668  onlyReadsHandle(getParam(), effects);
1669  onlyReadsHandle(getReference(), effects);
1670 }
1671 
1672 //===----------------------------------------------------------------------===//
1673 // ParamConstantOp
1674 //===----------------------------------------------------------------------===//
1675 
1677 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
1678  transform::TransformResults &results,
1679  transform::TransformState &state) {
1680  results.setParams(cast<OpResult>(getParam()), {getValue()});
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // MergeHandlesOp
1686 //===----------------------------------------------------------------------===//
1687 
1689 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
1690  transform::TransformResults &results,
1691  transform::TransformState &state) {
1692  ValueRange handles = getHandles();
1693  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
1694  SmallVector<Operation *> operations;
1695  for (Value operand : handles)
1696  llvm::append_range(operations, state.getPayloadOps(operand));
1697  if (!getDeduplicate()) {
1698  results.set(llvm::cast<OpResult>(getResult()), operations);
1700  }
1701 
1702  SetVector<Operation *> uniqued(operations.begin(), operations.end());
1703  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1705  }
1706 
1707  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1708  SmallVector<Attribute> attrs;
1709  for (Value attribute : handles)
1710  llvm::append_range(attrs, state.getParams(attribute));
1711  if (!getDeduplicate()) {
1712  results.setParams(cast<OpResult>(getResult()), attrs);
1714  }
1715 
1716  SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
1717  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1719  }
1720 
1721  assert(
1722  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1723  "expected value handle type");
1724  SmallVector<Value> payloadValues;
1725  for (Value value : handles)
1726  llvm::append_range(payloadValues, state.getPayloadValues(value));
1727  if (!getDeduplicate()) {
1728  results.setValues(cast<OpResult>(getResult()), payloadValues);
1730  }
1731 
1732  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
1733  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1735 }
1736 
1737 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1738  // Handles may be the same if deduplicating is enabled.
1739  return getDeduplicate();
1740 }
1741 
1742 void transform::MergeHandlesOp::getEffects(
1743  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1744  onlyReadsHandle(getHandles(), effects);
1745  producesHandle(getResult(), effects);
1746 
1747  // There are no effects on the Payload IR as this is only a handle
1748  // manipulation.
1749 }
1750 
1751 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1752  if (getDeduplicate() || getHandles().size() != 1)
1753  return {};
1754 
1755  // If deduplication is not required and there is only one operand, it can be
1756  // used directly instead of merging.
1757  return getHandles().front();
1758 }
1759 
1760 //===----------------------------------------------------------------------===//
1761 // NamedSequenceOp
1762 //===----------------------------------------------------------------------===//
1763 
1765 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
1766  transform::TransformResults &results,
1767  transform::TransformState &state) {
1768  if (isExternal())
1769  return emitDefiniteFailure() << "unresolved external named sequence";
1770 
1771  // Map the entry block argument to the list of operations.
1772  // Note: this is the same implementation as PossibleTopLevelTransformOp but
1773  // without attaching the interface / trait since that is tailored to a
1774  // dangling top-level op that does not get "called".
1775  auto scope = state.make_region_scope(getBody());
1777  state, this->getOperation(), getBody())))
1779 
1780  return applySequenceBlock(getBody().front(),
1781  FailurePropagationMode::Propagate, state, results);
1782 }
1783 
1784 void transform::NamedSequenceOp::getEffects(
1785  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
1786 
1788  OperationState &result) {
1790  parser, result, /*allowVariadic=*/false,
1791  getFunctionTypeAttrName(result.name),
1792  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
1794  std::string &) { return builder.getFunctionType(inputs, results); },
1795  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1796 }
1797 
1800  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
1801  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
1802  getResAttrsAttrName());
1803 }
1804 
1805 /// Verifies that a symbol function-like transform dialect operation has the
1806 /// signature and the terminator that have conforming types, i.e., types
1807 /// implementing the same transform dialect type interface. If `allowExternal`
1808 /// is set, allow external symbols (declarations) and don't check the terminator
1809 /// as it may not exist.
1811 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
1812  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
1815  << "cannot be defined inside another transform op";
1816  diag.attachNote(parent.getLoc()) << "ancestor transform op";
1817  return diag;
1818  }
1819 
1820  if (op.isExternal() || op.getFunctionBody().empty()) {
1821  if (allowExternal)
1823 
1824  return emitSilenceableFailure(op) << "cannot be external";
1825  }
1826 
1827  if (op.getFunctionBody().front().empty())
1828  return emitSilenceableFailure(op) << "expected a non-empty body block";
1829 
1830  Operation *terminator = &op.getFunctionBody().front().back();
1831  if (!isa<transform::YieldOp>(terminator)) {
1833  << "expected '"
1834  << transform::YieldOp::getOperationName()
1835  << "' as terminator";
1836  diag.attachNote(terminator->getLoc()) << "terminator";
1837  return diag;
1838  }
1839 
1840  if (terminator->getNumOperands() != op.getResultTypes().size()) {
1841  return emitSilenceableFailure(terminator)
1842  << "expected terminator to have as many operands as the parent op "
1843  "has results";
1844  }
1845  for (auto [i, operandType, resultType] : llvm::zip_equal(
1846  llvm::seq<unsigned>(0, terminator->getNumOperands()),
1847  terminator->getOperands().getType(), op.getResultTypes())) {
1848  if (operandType == resultType)
1849  continue;
1850  return emitSilenceableFailure(terminator)
1851  << "the type of the terminator operand #" << i
1852  << " must match the type of the corresponding parent op result ("
1853  << operandType << " vs " << resultType << ")";
1854  }
1855 
1857 }
1858 
1859 /// Verification of a NamedSequenceOp. This does not report the error
1860 /// immediately, so it can be used to check for op's well-formedness before the
1861 /// verifier runs, e.g., during trait verification.
1863 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
1864  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
1865  if (!parent->getAttr(
1866  transform::TransformDialect::kWithNamedSequenceAttrName)) {
1869  << "expects the parent symbol table to have the '"
1870  << transform::TransformDialect::kWithNamedSequenceAttrName
1871  << "' attribute";
1872  diag.attachNote(parent->getLoc()) << "symbol table operation";
1873  return diag;
1874  }
1875  }
1876 
1877  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
1880  << "cannot be defined inside another transform op";
1881  diag.attachNote(parent.getLoc()) << "ancestor transform op";
1882  return diag;
1883  }
1884 
1885  if (op.isExternal() || op.getBody().empty())
1886  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
1887  emitWarnings);
1888 
1889  if (op.getBody().front().empty())
1890  return emitSilenceableFailure(op) << "expected a non-empty body block";
1891 
1892  Operation *terminator = &op.getBody().front().back();
1893  if (!isa<transform::YieldOp>(terminator)) {
1895  << "expected '"
1896  << transform::YieldOp::getOperationName()
1897  << "' as terminator";
1898  diag.attachNote(terminator->getLoc()) << "terminator";
1899  return diag;
1900  }
1901 
1902  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
1903  return emitSilenceableFailure(terminator)
1904  << "expected terminator to have as many operands as the parent op "
1905  "has results";
1906  }
1907  for (auto [i, operandType, resultType] :
1908  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
1909  terminator->getOperands().getType(),
1910  op.getFunctionType().getResults())) {
1911  if (operandType == resultType)
1912  continue;
1913  return emitSilenceableFailure(terminator)
1914  << "the type of the terminator operand #" << i
1915  << " must match the type of the corresponding parent op result ("
1916  << operandType << " vs " << resultType << ")";
1917  }
1918 
1919  auto funcOp = cast<FunctionOpInterface>(*op);
1921  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
1922  if (!diag.succeeded())
1923  return diag;
1924 
1925  return verifyYieldingSingleBlockOp(funcOp,
1926  /*allowExternal=*/true);
1927 }
1928 
1930  // Actual verification happens in a separate function for reusability.
1931  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
1932 }
1933 
1934 template <typename FnTy>
1935 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
1936  Type bbArgType, TypeRange extraBindingTypes,
1937  FnTy bodyBuilder) {
1938  SmallVector<Type> types;
1939  types.reserve(1 + extraBindingTypes.size());
1940  types.push_back(bbArgType);
1941  llvm::append_range(types, extraBindingTypes);
1942 
1943  OpBuilder::InsertionGuard guard(builder);
1944  Region *region = state.regions.back().get();
1945  Block *bodyBlock =
1946  builder.createBlock(region, region->begin(), types,
1947  SmallVector<Location>(types.size(), state.location));
1948 
1949  // Populate body.
1950  builder.setInsertionPointToStart(bodyBlock);
1951  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
1952  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
1953  } else {
1954  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
1955  bodyBlock->getArguments().drop_front());
1956  }
1957 }
1958 
1959 void transform::NamedSequenceOp::build(OpBuilder &builder,
1960  OperationState &state, StringRef symName,
1961  Type rootType, TypeRange resultTypes,
1962  SequenceBodyBuilderFn bodyBuilder,
1964  ArrayRef<DictionaryAttr> argAttrs) {
1965  state.addAttribute(SymbolTable::getSymbolAttrName(),
1966  builder.getStringAttr(symName));
1967  state.addAttribute(getFunctionTypeAttrName(state.name),
1969  rootType, resultTypes)));
1970  state.attributes.append(attrs.begin(), attrs.end());
1971  state.addRegion();
1972 
1973  buildSequenceBody(builder, state, rootType,
1974  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
1975 }
1976 
1977 //===----------------------------------------------------------------------===//
1978 // SelectOp
1979 //===----------------------------------------------------------------------===//
1980 
1982 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
1983  transform::TransformResults &results,
1984  transform::TransformState &state) {
1985  SmallVector<Operation *> result;
1986  auto payloadOps = state.getPayloadOps(getTarget());
1987  for (Operation *op : payloadOps) {
1988  if (op->getName().getStringRef() == getOpName())
1989  result.push_back(op);
1990  }
1991  results.set(cast<OpResult>(getResult()), result);
1993 }
1994 
1995 //===----------------------------------------------------------------------===//
1996 // SplitHandleOp
1997 //===----------------------------------------------------------------------===//
1998 
1999 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2000  Value target, int64_t numResultHandles) {
2001  result.addOperands(target);
2002  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2003 }
2004 
2006 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2007  transform::TransformResults &results,
2008  transform::TransformState &state) {
2009  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2010  auto produceNumOpsError = [&]() {
2011  return emitSilenceableError()
2012  << getHandle() << " expected to contain " << this->getNumResults()
2013  << " payload ops but it contains " << numPayloadOps
2014  << " payload ops";
2015  };
2016 
2017  // Fail if there are more payload ops than results and no overflow result was
2018  // specified.
2019  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2020  return produceNumOpsError();
2021 
2022  // Fail if there are more results than payload ops. Unless:
2023  // - "fail_on_payload_too_small" is set to "false", or
2024  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2025  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2026  !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
2027  return produceNumOpsError();
2028 
2029  // Distribute payload ops.
2030  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2031  if (getOverflowResult())
2032  resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2033  getNumResults());
2034  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2035  int64_t resultNum = en.index();
2036  if (resultNum >= getNumResults())
2037  resultNum = *getOverflowResult();
2038  resultHandles[resultNum].push_back(en.value());
2039  }
2040 
2041  // Set transform op results.
2042  for (auto &&it : llvm::enumerate(resultHandles))
2043  results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2044 
2046 }
2047 
2048 void transform::SplitHandleOp::getEffects(
2049  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2050  onlyReadsHandle(getHandle(), effects);
2051  producesHandle(getResults(), effects);
2052  // There are no effects on the Payload IR as this is only a handle
2053  // manipulation.
2054 }
2055 
2057  if (getOverflowResult().has_value() &&
2058  !(*getOverflowResult() < getNumResults()))
2059  return emitOpError("overflow_result is not a valid result index");
2060  return success();
2061 }
2062 
2063 //===----------------------------------------------------------------------===//
2064 // ReplicateOp
2065 //===----------------------------------------------------------------------===//
2066 
2068 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2069  transform::TransformResults &results,
2070  transform::TransformState &state) {
2071  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2072  for (const auto &en : llvm::enumerate(getHandles())) {
2073  Value handle = en.value();
2074  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2075  SmallVector<Operation *> current =
2076  llvm::to_vector(state.getPayloadOps(handle));
2077  SmallVector<Operation *> payload;
2078  payload.reserve(numRepetitions * current.size());
2079  for (unsigned i = 0; i < numRepetitions; ++i)
2080  llvm::append_range(payload, current);
2081  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2082  } else {
2083  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2084  "expected param type");
2085  ArrayRef<Attribute> current = state.getParams(handle);
2086  SmallVector<Attribute> params;
2087  params.reserve(numRepetitions * current.size());
2088  for (unsigned i = 0; i < numRepetitions; ++i)
2089  llvm::append_range(params, current);
2090  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2091  params);
2092  }
2093  }
2095 }
2096 
2097 void transform::ReplicateOp::getEffects(
2098  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2099  onlyReadsHandle(getPattern(), effects);
2100  onlyReadsHandle(getHandles(), effects);
2101  producesHandle(getReplicated(), effects);
2102 }
2103 
2104 //===----------------------------------------------------------------------===//
2105 // SequenceOp
2106 //===----------------------------------------------------------------------===//
2107 
2109 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2110  transform::TransformResults &results,
2111  transform::TransformState &state) {
2112  // Map the entry block argument to the list of operations.
2113  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2114  if (failed(mapBlockArguments(state)))
2116 
2117  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2118  results);
2119 }
2120 
2122  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2123  Type &rootType,
2124  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2125  SmallVectorImpl<Type> &extraBindingTypes) {
2126  OpAsmParser::UnresolvedOperand rootOperand;
2127  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2128  if (!hasRoot.has_value()) {
2129  root = std::nullopt;
2130  return success();
2131  }
2132  if (failed(hasRoot.value()))
2133  return failure();
2134  root = rootOperand;
2135 
2136  if (succeeded(parser.parseOptionalComma())) {
2137  if (failed(parser.parseOperandList(extraBindings)))
2138  return failure();
2139  }
2140  if (failed(parser.parseColon()))
2141  return failure();
2142 
2143  // The paren is truly optional.
2144  (void)parser.parseOptionalLParen();
2145 
2146  if (failed(parser.parseType(rootType))) {
2147  return failure();
2148  }
2149 
2150  if (!extraBindings.empty()) {
2151  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2152  return failure();
2153  }
2154 
2155  if (extraBindingTypes.size() != extraBindings.size()) {
2156  return parser.emitError(parser.getNameLoc(),
2157  "expected types to be provided for all operands");
2158  }
2159 
2160  // The paren is truly optional.
2161  (void)parser.parseOptionalRParen();
2162  return success();
2163 }
2164 
2166  Value root, Type rootType,
2167  ValueRange extraBindings,
2168  TypeRange extraBindingTypes) {
2169  if (!root)
2170  return;
2171 
2172  printer << root;
2173  bool hasExtras = !extraBindings.empty();
2174  if (hasExtras) {
2175  printer << ", ";
2176  printer.printOperands(extraBindings);
2177  }
2178 
2179  printer << " : ";
2180  if (hasExtras)
2181  printer << "(";
2182 
2183  printer << rootType;
2184  if (hasExtras) {
2185  printer << ", ";
2186  llvm::interleaveComma(extraBindingTypes, printer.getStream());
2187  printer << ")";
2188  }
2189 }
2190 
2191 /// Returns `true` if the given op operand may be consuming the handle value in
2192 /// the Transform IR. That is, if it may have a Free effect on it.
2194  // Conservatively assume the effect being present in absence of the interface.
2195  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2196  if (!iface)
2197  return true;
2198 
2199  return isHandleConsumed(use.get(), iface);
2200 }
2201 
2204  function_ref<InFlightDiagnostic()> reportError) {
2205  OpOperand *potentialConsumer = nullptr;
2206  for (OpOperand &use : value.getUses()) {
2207  if (!isValueUsePotentialConsumer(use))
2208  continue;
2209 
2210  if (!potentialConsumer) {
2211  potentialConsumer = &use;
2212  continue;
2213  }
2214 
2215  InFlightDiagnostic diag = reportError()
2216  << " has more than one potential consumer";
2217  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2218  << "used here as operand #" << potentialConsumer->getOperandNumber();
2219  diag.attachNote(use.getOwner()->getLoc())
2220  << "used here as operand #" << use.getOperandNumber();
2221  return diag;
2222  }
2223 
2224  return success();
2225 }
2226 
2228  assert(getBodyBlock()->getNumArguments() >= 1 &&
2229  "the number of arguments must have been verified to be more than 1 by "
2230  "PossibleTopLevelTransformOpTrait");
2231 
2232  if (!getRoot() && !getExtraBindings().empty()) {
2233  return emitOpError()
2234  << "does not expect extra operands when used as top-level";
2235  }
2236 
2237  // Check if a block argument has more than one consuming use.
2238  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2239  if (failed(checkDoubleConsume(arg, [this, arg]() {
2240  return (emitOpError() << "block argument #" << arg.getArgNumber());
2241  }))) {
2242  return failure();
2243  }
2244  }
2245 
2246  // Check properties of the nested operations they cannot check themselves.
2247  for (Operation &child : *getBodyBlock()) {
2248  if (!isa<TransformOpInterface>(child) &&
2249  &child != &getBodyBlock()->back()) {
2251  emitOpError()
2252  << "expected children ops to implement TransformOpInterface";
2253  diag.attachNote(child.getLoc()) << "op without interface";
2254  return diag;
2255  }
2256 
2257  for (OpResult result : child.getResults()) {
2258  auto report = [&]() {
2259  return (child.emitError() << "result #" << result.getResultNumber());
2260  };
2261  if (failed(checkDoubleConsume(result, report)))
2262  return failure();
2263  }
2264  }
2265 
2266  if (!getBodyBlock()->mightHaveTerminator())
2267  return emitOpError() << "expects to have a terminator in the body";
2268 
2269  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2270  getOperation()->getResultTypes()) {
2271  InFlightDiagnostic diag = emitOpError()
2272  << "expects the types of the terminator operands "
2273  "to match the types of the result";
2274  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2275  return diag;
2276  }
2277  return success();
2278 }
2279 
2280 void transform::SequenceOp::getEffects(
2281  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2282  getPotentialTopLevelEffects(effects);
2283 }
2284 
2286 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2287  assert(point == getBody() && "unexpected region index");
2288  if (getOperation()->getNumOperands() > 0)
2289  return getOperation()->getOperands();
2290  return OperandRange(getOperation()->operand_end(),
2291  getOperation()->operand_end());
2292 }
2293 
2294 void transform::SequenceOp::getSuccessorRegions(
2295  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2296  if (point.isParent()) {
2297  Region *bodyRegion = &getBody();
2298  regions.emplace_back(bodyRegion, getNumOperands() != 0
2299  ? bodyRegion->getArguments()
2301  return;
2302  }
2303 
2304  assert(point == getBody() && "unexpected region index");
2305  regions.emplace_back(getOperation()->getResults());
2306 }
2307 
2308 void transform::SequenceOp::getRegionInvocationBounds(
2309  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2310  (void)operands;
2311  bounds.emplace_back(1, 1);
2312 }
2313 
2314 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2315  TypeRange resultTypes,
2316  FailurePropagationMode failurePropagationMode,
2317  Value root,
2318  SequenceBodyBuilderFn bodyBuilder) {
2319  build(builder, state, resultTypes, failurePropagationMode, root,
2320  /*extra_bindings=*/ValueRange());
2321  Type bbArgType = root.getType();
2322  buildSequenceBody(builder, state, bbArgType,
2323  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2324 }
2325 
2326 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2327  TypeRange resultTypes,
2328  FailurePropagationMode failurePropagationMode,
2329  Value root, ValueRange extraBindings,
2330  SequenceBodyBuilderArgsFn bodyBuilder) {
2331  build(builder, state, resultTypes, failurePropagationMode, root,
2332  extraBindings);
2333  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2334  bodyBuilder);
2335 }
2336 
2337 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2338  TypeRange resultTypes,
2339  FailurePropagationMode failurePropagationMode,
2340  Type bbArgType,
2341  SequenceBodyBuilderFn bodyBuilder) {
2342  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2343  /*extra_bindings=*/ValueRange());
2344  buildSequenceBody(builder, state, bbArgType,
2345  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2346 }
2347 
2348 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2349  TypeRange resultTypes,
2350  FailurePropagationMode failurePropagationMode,
2351  Type bbArgType, TypeRange extraBindingTypes,
2352  SequenceBodyBuilderArgsFn bodyBuilder) {
2353  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2354  /*extra_bindings=*/ValueRange());
2355  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2356 }
2357 
2358 //===----------------------------------------------------------------------===//
2359 // PrintOp
2360 //===----------------------------------------------------------------------===//
2361 
2362 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2363  StringRef name) {
2364  if (!name.empty())
2365  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2366 }
2367 
2368 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2369  Value target, StringRef name) {
2370  result.addOperands({target});
2371  build(builder, result, name);
2372 }
2373 
2375 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2376  transform::TransformResults &results,
2377  transform::TransformState &state) {
2378  llvm::outs() << "[[[ IR printer: ";
2379  if (getName().has_value())
2380  llvm::outs() << *getName() << " ";
2381 
2382  if (!getTarget()) {
2383  llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
2385  }
2386 
2387  llvm::outs() << "]]]\n";
2388  for (Operation *target : state.getPayloadOps(getTarget()))
2389  llvm::outs() << *target << "\n";
2390 
2392 }
2393 
2394 void transform::PrintOp::getEffects(
2395  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2396  onlyReadsHandle(getTarget(), effects);
2397  onlyReadsPayload(effects);
2398 
2399  // There is no resource for stderr file descriptor, so just declare print
2400  // writes into the default resource.
2401  effects.emplace_back(MemoryEffects::Write::get());
2402 }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // VerifyOp
2406 //===----------------------------------------------------------------------===//
2407 
2409 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2410  Operation *target,
2412  transform::TransformState &state) {
2413  if (failed(::mlir::verify(target))) {
2415  << "failed to verify payload op";
2416  diag.attachNote(target->getLoc()) << "payload op";
2417  return diag;
2418  }
2420 }
2421 
2422 void transform::VerifyOp::getEffects(
2423  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2424  transform::onlyReadsHandle(getTarget(), effects);
2425 }
2426 
2427 //===----------------------------------------------------------------------===//
2428 // YieldOp
2429 //===----------------------------------------------------------------------===//
2430 
2431 void transform::YieldOp::getEffects(
2432  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2433  onlyReadsHandle(getOperands(), effects);
2434 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static DiagnosedSilenceableFailure matchBlock(Block &block, Operation *op, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block assigning op as the payload of the block's first argu...
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:72
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:202
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A class for computing basic dominance information.
Definition: Dominance.h:121
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:305
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
This is a value defined by a result of an operation.
Definition: Value.h:453
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
type_range getTypes() const
Definition: ValueRange.cpp:26
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:512
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:686
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:211
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:49
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:55
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
Definition: TypeRange.h:145
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
user_range getUsers() const
Definition: Value.h:224
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Applies the specified rewrite patterns on ops while also trying to fold these ops.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.