MLIR 23.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/Support/Debug.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/ErrorHandling.h"
43#include "llvm/Support/InterleavedRange.h"
44#include <optional>
45
46#define DEBUG_TYPE "transform-dialect"
47#define DEBUG_TYPE_MATCHER "transform-matcher"
48
49using namespace mlir;
50
51static ParseResult parseApplyRegisteredPassOptions(
52 OpAsmParser &parser, DictionaryAttr &options,
55 Operation *op,
56 DictionaryAttr options,
57 ValueRange dynamicOptions);
58static ParseResult parseSequenceOpOperands(
59 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
60 Type &rootType,
62 SmallVectorImpl<Type> &extraBindingTypes);
63static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
64 Value root, Type rootType,
65 ValueRange extraBindings,
66 TypeRange extraBindingTypes);
67static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
69static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
71 ArrayAttr &actions);
72
73/// Helper function to check if the given transform op is contained in (or
74/// equal to) the given payload target op. In that case, an error is returned.
75/// Transforming transform IR that is currently executing is generally unsafe.
77ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
78 Operation *payload) {
79 Operation *transformAncestor = transform.getOperation();
80 while (transformAncestor) {
81 if (transformAncestor == payload) {
83 transform.emitDefiniteFailure()
84 << "cannot apply transform to itself (or one of its ancestors)";
85 diag.attachNote(payload->getLoc()) << "target payload op";
86 return diag;
87 }
88 transformAncestor = transformAncestor->getParentOp();
89 }
91}
92
93#define GET_OP_CLASSES
94#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
95
96//===----------------------------------------------------------------------===//
97// AlternativesOp
98//===----------------------------------------------------------------------===//
99
100OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
101 RegionSuccessor successor) {
102 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
103 return getOperation()->getOperands();
104 return OperandRange(getOperation()->operand_end(),
105 getOperation()->operand_end());
106}
107
108void transform::AlternativesOp::getSuccessorRegions(
110 for (Region &alternative : llvm::drop_begin(
111 getAlternatives(), point.isParent()
112 ? 0
114 ->getParentRegion()
115 ->getRegionNumber() +
116 1)) {
117 regions.emplace_back(&alternative);
118 }
119 if (!point.isParent())
120 regions.push_back(RegionSuccessor::parent());
121}
122
124transform::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) {
125 if (successor.isParent())
126 return getOperation()->getResults();
127 return successor.getSuccessor()->getArguments();
128}
129
130void transform::AlternativesOp::getRegionInvocationBounds(
132 (void)operands;
133 // The region corresponding to the first alternative is always executed, the
134 // remaining may or may not be executed.
135 bounds.reserve(getNumRegions());
136 bounds.emplace_back(1, 1);
137 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
138}
139
142 for (const auto &res : block->getParentOp()->getOpResults())
143 results.set(res, {});
144}
145
147transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
150 SmallVector<Operation *> originals;
151 if (Value scopeHandle = getScope())
152 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
153 else
154 originals.push_back(state.getTopLevel());
155
156 for (Operation *original : originals) {
157 if (original->isAncestor(getOperation())) {
159 << "scope must not contain the transforms being applied";
160 diag.attachNote(original->getLoc()) << "scope";
161 return diag;
162 }
163 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
165 << "only isolated-from-above ops can be alternative scopes";
166 diag.attachNote(original->getLoc()) << "scope";
167 return diag;
168 }
169 }
170
171 for (Region &reg : getAlternatives()) {
172 // Clone the scope operations and make the transforms in this alternative
173 // region apply to them by virtue of mapping the block argument (the only
174 // visible handle) to the cloned scope operations. This effectively prevents
175 // the transformation from accessing any IR outside the scope.
176 auto scope = state.make_region_scope(reg);
177 auto clones = llvm::map_to_vector(
178 originals, [](Operation *op) { return op->clone(); });
179 llvm::scope_exit deleteClones([&] {
180 for (Operation *clone : clones)
181 clone->erase();
182 });
183 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
185
186 bool failed = false;
187 for (Operation &transform : reg.front().without_terminator()) {
189 state.applyTransform(cast<TransformOpInterface>(transform));
190 if (result.isSilenceableFailure()) {
191 LDBG() << "alternative failed: " << result.getMessage();
192 failed = true;
193 break;
194 }
195
196 if (::mlir::failed(result.silence()))
198 }
199
200 // If all operations in the given alternative succeeded, no need to consider
201 // the rest. Replace the original scoping operation with the clone on which
202 // the transformations were performed.
203 if (!failed) {
204 // We will be using the clones, so cancel their scheduled deletion.
205 deleteClones.release();
206 TrackingListener listener(state, *this);
207 IRRewriter rewriter(getContext(), &listener);
208 for (const auto &kvp : llvm::zip(originals, clones)) {
209 Operation *original = std::get<0>(kvp);
210 Operation *clone = std::get<1>(kvp);
211 original->getBlock()->getOperations().insert(original->getIterator(),
212 clone);
213 rewriter.replaceOp(original, clone->getResults());
214 }
215 detail::forwardTerminatorOperands(&reg.front(), state, results);
217 }
218 }
219 return emitSilenceableError() << "all alternatives failed";
220}
221
222void transform::AlternativesOp::getEffects(
224 consumesHandle(getOperation()->getOpOperands(), effects);
225 producesHandle(getOperation()->getOpResults(), effects);
226 for (Region *region : getRegions()) {
227 if (!region->empty())
228 producesHandle(region->front().getArguments(), effects);
229 }
230 modifiesPayload(effects);
231}
232
233LogicalResult transform::AlternativesOp::verify() {
234 for (Region &alternative : getAlternatives()) {
235 Block &block = alternative.front();
236 Operation *terminator = block.getTerminator();
237 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
239 << "expects terminator operands to have the "
240 "same type as results of the operation";
241 diag.attachNote(terminator->getLoc()) << "terminator";
242 return diag;
243 }
244 }
245
246 return success();
247}
248
249//===----------------------------------------------------------------------===//
250// AnnotateOp
251//===----------------------------------------------------------------------===//
252
254transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
258 llvm::to_vector(state.getPayloadOps(getTarget()));
259
260 Attribute attr = UnitAttr::get(getContext());
261 if (auto paramH = getParam()) {
262 ArrayRef<Attribute> params = state.getParams(paramH);
263 if (params.size() != 1) {
264 if (targets.size() != params.size()) {
265 return emitSilenceableError()
266 << "parameter and target have different payload lengths ("
267 << params.size() << " vs " << targets.size() << ")";
268 }
269 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
270 target->setAttr(getName(), attr);
272 }
273 attr = params[0];
274 }
275 for (auto *target : targets)
276 target->setAttr(getName(), attr);
278}
279
280void transform::AnnotateOp::getEffects(
282 onlyReadsHandle(getTargetMutable(), effects);
283 onlyReadsHandle(getParamMutable(), effects);
284 modifiesPayload(effects);
285}
286
287//===----------------------------------------------------------------------===//
288// ApplyCommonSubexpressionEliminationOp
289//===----------------------------------------------------------------------===//
290
292transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
294 ApplyToEachResultList &results, transform::TransformState &state) {
295 // Make sure that this transform is not applied to itself. Modifying the
296 // transform IR while it is being interpreted is generally dangerous.
297 DiagnosedSilenceableFailure payloadCheck =
299 if (!payloadCheck.succeeded())
300 return payloadCheck;
301
302 DominanceInfo domInfo;
305}
306
307void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
309 transform::onlyReadsHandle(getTargetMutable(), effects);
311}
312
313//===----------------------------------------------------------------------===//
314// ApplyDeadCodeEliminationOp
315//===----------------------------------------------------------------------===//
316
317DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
319 ApplyToEachResultList &results, transform::TransformState &state) {
320 // Make sure that this transform is not applied to itself. Modifying the
321 // transform IR while it is being interpreted is generally dangerous.
322 DiagnosedSilenceableFailure payloadCheck =
324 if (!payloadCheck.succeeded())
325 return payloadCheck;
326
327 // Maintain a worklist of potentially dead ops.
328 SetVector<Operation *> worklist;
329
330 // Helper function that adds all defining ops of used values (operands and
331 // operands of nested ops).
332 auto addDefiningOpsToWorklist = [&](Operation *op) {
333 op->walk([&](Operation *op) {
334 for (Value v : op->getOperands())
335 if (Operation *defOp = v.getDefiningOp())
336 if (target->isProperAncestor(defOp))
337 worklist.insert(defOp);
338 });
339 };
340
341 // Helper function that erases an op.
342 auto eraseOp = [&](Operation *op) {
343 // Remove op and nested ops from the worklist.
344 op->walk([&](Operation *op) {
345 const auto *it = llvm::find(worklist, op);
346 if (it != worklist.end())
347 worklist.erase(it);
348 });
349 rewriter.eraseOp(op);
350 };
351
352 // Initial walk over the IR.
353 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
354 if (op != target && isOpTriviallyDead(op)) {
355 addDefiningOpsToWorklist(op);
356 eraseOp(op);
357 }
358 });
359
360 // Erase all ops that have become dead.
361 while (!worklist.empty()) {
362 Operation *op = worklist.pop_back_val();
363 if (!isOpTriviallyDead(op))
364 continue;
365 addDefiningOpsToWorklist(op);
366 eraseOp(op);
367 }
368
370}
371
372void transform::ApplyDeadCodeEliminationOp::getEffects(
374 transform::onlyReadsHandle(getTargetMutable(), effects);
376}
377
378//===----------------------------------------------------------------------===//
379// ApplyPatternsOp
380//===----------------------------------------------------------------------===//
381
382DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
384 ApplyToEachResultList &results, transform::TransformState &state) {
385 // Make sure that this transform is not applied to itself. Modifying the
386 // transform IR while it is being interpreted is generally dangerous. Even
387 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
388 // performs many additional simplifications such as dead code elimination.
389 DiagnosedSilenceableFailure payloadCheck =
391 if (!payloadCheck.succeeded())
392 return payloadCheck;
393
394 // Gather all specified patterns.
395 MLIRContext *ctx = target->getContext();
396 RewritePatternSet patterns(ctx);
397 if (!getRegion().empty()) {
398 for (Operation &op : getRegion().front()) {
399 cast<transform::PatternDescriptorOpInterface>(&op)
400 .populatePatternsWithState(patterns, state);
401 }
402 }
403
404 // Configure the GreedyPatternRewriteDriver.
405 GreedyRewriteConfig config;
406 config.setListener(
407 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
408 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
409
410 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
412 : getMaxIterations());
413 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
415 : getMaxNumRewrites());
416
417 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
418 // Op is isolated from above. The greedy driver iterates to a fixpoint
419 // internally and optionally runs full CSE between iterations.
420 config.enableCSEBetweenIterations(getApplyCse());
421 if (failed(applyPatternsGreedily(target, frozenPatterns, config))) {
423 << "greedy pattern application failed";
424 }
426 }
427
428 // Non-isolated case: gather the ops manually because the op-list
429 // GreedyPatternRewriteDriver overload only performs a single iteration and
430 // does not simplify regions. CSE is driven externally to reach a fixpoint.
432 target->walk([&](Operation *nestedOp) {
433 if (target != nestedOp)
434 ops.push_back(nestedOp);
435 });
436
437 // One or two iterations should be sufficient. Stop iterating after a certain
438 // threshold to make debugging easier.
439 static const int64_t kNumMaxIterations = 50;
440 int64_t iteration = 0;
441 bool cseChanged = false;
442 do {
443 if (failed(applyOpPatternsGreedily(ops, frozenPatterns, config))) {
445 << "greedy pattern application failed";
446 }
447
448 if (getApplyCse()) {
449 DominanceInfo domInfo;
451 &cseChanged);
452 }
453 } while (cseChanged && ++iteration < kNumMaxIterations);
454
455 if (iteration == kNumMaxIterations)
456 return emitDefiniteFailure() << "fixpoint iteration did not converge";
457
459}
460
461LogicalResult transform::ApplyPatternsOp::verify() {
462 if (!getRegion().empty()) {
463 for (Operation &op : getRegion().front()) {
464 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
466 << "expected children ops to implement "
467 "PatternDescriptorOpInterface";
468 diag.attachNote(op.getLoc()) << "op without interface";
469 return diag;
470 }
471 }
472 }
473 return success();
474}
475
476void transform::ApplyPatternsOp::getEffects(
478 transform::onlyReadsHandle(getTargetMutable(), effects);
480}
481
482void transform::ApplyPatternsOp::build(
484 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
485 result.addOperands(target);
486
487 OpBuilder::InsertionGuard g(builder);
488 Region *region = result.addRegion();
489 builder.createBlock(region);
490 if (bodyBuilder)
491 bodyBuilder(builder, result.location);
492}
493
494//===----------------------------------------------------------------------===//
495// ApplyCanonicalizationPatternsOp
496//===----------------------------------------------------------------------===//
497
498void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
499 RewritePatternSet &patterns) {
500 MLIRContext *ctx = patterns.getContext();
501 for (Dialect *dialect : ctx->getLoadedDialects())
502 dialect->getCanonicalizationPatterns(patterns);
504 op.getCanonicalizationPatterns(patterns, ctx);
505}
506
507//===----------------------------------------------------------------------===//
508// ApplyConversionPatternsOp
509//===----------------------------------------------------------------------===//
510
511DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
514 MLIRContext *ctx = getContext();
515
516 // Instantiate the default type converter if a type converter builder is
517 // specified.
518 std::unique_ptr<TypeConverter> defaultTypeConverter;
519 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
520 getDefaultTypeConverter();
521 if (typeConverterBuilder)
522 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
523
524 // Configure conversion target.
525 ConversionTarget conversionTarget(*getContext());
526 if (getLegalOps())
527 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
528 conversionTarget.addLegalOp(
529 OperationName(cast<StringAttr>(attr).getValue(), ctx));
530 if (getIllegalOps())
531 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
532 conversionTarget.addIllegalOp(
533 OperationName(cast<StringAttr>(attr).getValue(), ctx));
534 if (getLegalDialects())
535 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
536 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
537 if (getIllegalDialects())
538 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
539 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
540
541 // Gather all specified patterns.
542 RewritePatternSet patterns(ctx);
543 // Need to keep the converters alive until after pattern application because
544 // the patterns take a reference to an object that would otherwise get out of
545 // scope.
547 if (!getPatterns().empty()) {
548 for (Operation &op : getPatterns().front()) {
549 auto descriptor =
550 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
551
552 // Check if this pattern set specifies a type converter.
553 std::unique_ptr<TypeConverter> typeConverter =
554 descriptor.getTypeConverter();
555 TypeConverter *converter = nullptr;
556 if (typeConverter) {
557 keepAliveConverters.emplace_back(std::move(typeConverter));
558 converter = keepAliveConverters.back().get();
559 } else {
560 // No type converter specified: Use the default type converter.
561 if (!defaultTypeConverter) {
563 << "pattern descriptor does not specify type "
564 "converter and apply_conversion_patterns op has "
565 "no default type converter";
566 diag.attachNote(op.getLoc()) << "pattern descriptor op";
567 return diag;
568 }
569 converter = defaultTypeConverter.get();
570 }
571
572 // Add descriptor-specific updates to the conversion target, which may
573 // depend on the final type converter. In structural converters, the
574 // legality of types dictates the dynamic legality of an operation.
575 descriptor.populateConversionTargetRules(*converter, conversionTarget);
576
577 descriptor.populatePatterns(*converter, patterns);
578 }
579 }
580
581 // Attach a tracking listener if handles should be preserved. We configure the
582 // listener to allow op replacements with different names, as conversion
583 // patterns typically replace ops with replacement ops that have a different
584 // name.
585 TrackingListenerConfig trackingConfig;
586 trackingConfig.requireMatchingReplacementOpName = false;
587 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
588 ConversionConfig conversionConfig;
589 if (getPreserveHandles())
590 conversionConfig.listener = &trackingListener;
591
592 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
593 for (Operation *target : state.getPayloadOps(getTarget())) {
594 // Make sure that this transform is not applied to itself. Modifying the
595 // transform IR while it is being interpreted is generally dangerous.
596 DiagnosedSilenceableFailure payloadCheck =
598 if (!payloadCheck.succeeded())
599 return payloadCheck;
600
601 LogicalResult status = failure();
602 if (getPartialConversion()) {
603 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
604 conversionConfig);
605 } else {
606 status = applyFullConversion(target, conversionTarget, frozenPatterns,
607 conversionConfig);
608 }
609
610 // Check dialect conversion state.
612 if (failed(status)) {
613 diag = emitSilenceableError() << "dialect conversion failed";
614 diag.attachNote(target->getLoc()) << "target op";
615 }
616
617 // Check tracking listener error state.
618 DiagnosedSilenceableFailure trackingFailure =
619 trackingListener.checkAndResetError();
620 if (!trackingFailure.succeeded()) {
621 if (diag.succeeded()) {
622 // Tracking failure is the only failure.
623 return trackingFailure;
624 }
625 diag.attachNote() << "tracking listener also failed: "
626 << trackingFailure.getMessage();
627 (void)trackingFailure.silence();
628 }
629
630 if (!diag.succeeded())
631 return diag;
632 }
633
635}
636
637LogicalResult transform::ApplyConversionPatternsOp::verify() {
638 if (getNumRegions() != 1 && getNumRegions() != 2)
639 return emitOpError() << "expected 1 or 2 regions";
640 if (!getPatterns().empty()) {
641 for (Operation &op : getPatterns().front()) {
642 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
644 emitOpError() << "expected pattern children ops to implement "
645 "ConversionPatternDescriptorOpInterface";
646 diag.attachNote(op.getLoc()) << "op without interface";
647 return diag;
648 }
649 }
650 }
651 if (getNumRegions() == 2) {
652 Region &typeConverterRegion = getRegion(1);
653 if (!llvm::hasSingleElement(typeConverterRegion.front()))
654 return emitOpError()
655 << "expected exactly one op in default type converter region";
656 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
657 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
658 maybeTypeConverter);
659 if (!typeConverterOp) {
661 << "expected default converter child op to "
662 "implement TypeConverterBuilderOpInterface";
663 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
664 return diag;
665 }
666 // Check default type converter type.
667 if (!getPatterns().empty()) {
668 for (Operation &op : getPatterns().front()) {
669 auto descriptor =
670 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
671 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
672 return failure();
673 }
674 }
675 }
676 return success();
677}
678
679void transform::ApplyConversionPatternsOp::getEffects(
681 if (!getPreserveHandles()) {
682 transform::consumesHandle(getTargetMutable(), effects);
683 } else {
684 transform::onlyReadsHandle(getTargetMutable(), effects);
685 }
687}
688
689void transform::ApplyConversionPatternsOp::build(
691 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
692 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
693 result.addOperands(target);
694
695 {
696 OpBuilder::InsertionGuard g(builder);
697 Region *region1 = result.addRegion();
698 builder.createBlock(region1);
699 if (patternsBodyBuilder)
700 patternsBodyBuilder(builder, result.location);
701 }
702 {
703 OpBuilder::InsertionGuard g(builder);
704 Region *region2 = result.addRegion();
705 builder.createBlock(region2);
706 if (typeConverterBodyBuilder)
707 typeConverterBodyBuilder(builder, result.location);
708 }
709}
710
711//===----------------------------------------------------------------------===//
712// ApplyToLLVMConversionPatternsOp
713//===----------------------------------------------------------------------===//
714
715void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
716 TypeConverter &typeConverter, RewritePatternSet &patterns) {
717 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
718 assert(dialect && "expected that dialect is loaded");
719 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
720 // ConversionTarget is currently ignored because the enclosing
721 // apply_conversion_patterns op sets up its own ConversionTarget.
723 iface->populateConvertToLLVMConversionPatterns(
724 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
725}
726
727LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
728 transform::TypeConverterBuilderOpInterface builder) {
729 if (builder.getTypeConverterType() != "LLVMTypeConverter")
730 return emitOpError("expected LLVMTypeConverter");
731 return success();
732}
733
734LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
735 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
736 if (!dialect)
737 return emitOpError("unknown dialect or dialect not loaded: ")
738 << getDialectName();
739 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
740 if (!iface)
741 return emitOpError(
742 "dialect does not implement ConvertToLLVMPatternInterface or "
743 "extension was not loaded: ")
744 << getDialectName();
745 return success();
746}
747
748//===----------------------------------------------------------------------===//
749// ApplyLoopInvariantCodeMotionOp
750//===----------------------------------------------------------------------===//
751
753transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
754 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
757 // Currently, LICM does not remove operations, so we don't need tracking.
758 // If this ever changes, add a LICM entry point that takes a rewriter.
761}
762
763void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
765 transform::onlyReadsHandle(getTargetMutable(), effects);
767}
768
769//===----------------------------------------------------------------------===//
770// ApplyRegisteredPassOp
771//===----------------------------------------------------------------------===//
772
773void transform::ApplyRegisteredPassOp::getEffects(
775 consumesHandle(getTargetMutable(), effects);
776 onlyReadsHandle(getDynamicOptionsMutable(), effects);
777 producesHandle(getOperation()->getOpResults(), effects);
778 modifiesPayload(effects);
779}
780
782transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
785 // Obtain a single options-string to pass to the pass(-pipeline) from options
786 // passed in as a dictionary of keys mapping to values which are either
787 // attributes or param-operands pointing to attributes.
788 OperandRange dynamicOptions = getDynamicOptions();
789
790 std::string options;
791 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
792
793 // A helper to convert an option's attribute value into a corresponding
794 // string representation, with the ability to obtain the attr(s) from a param.
795 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
796 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
797 // The corresponding value attribute(s) is/are passed in via a param.
798 // Obtain the param-operand via its specified index.
799 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
800 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
801 "the number of ParamOperandAttrs in the options DictionaryAttr"
802 "should be the same as the number of options passed as params");
803 ArrayRef<Attribute> attrsAssociatedToParam =
804 state.getParams(dynamicOptions[dynamicOptionIdx]);
805 // Recursive so as to append all attrs associated to the param.
806 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
807 ",");
808 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
809 // Recursive so as to append all nested attrs of the array.
810 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
811 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
812 // Convert to unquoted string.
813 optionsStream << strAttr.getValue().str();
814 } else {
815 // For all other attributes, ask the attr to print itself (without type).
816 valueAttr.print(optionsStream, /*elideType=*/true);
817 }
818 };
819
820 // Convert the options DictionaryAttr into a single string.
821 llvm::interleave(
822 getOptions(), optionsStream,
823 [&](auto namedAttribute) {
824 optionsStream << namedAttribute.getName().str(); // Append the key.
825 optionsStream << "="; // And the key-value separator.
826 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
827 },
828 " ");
829 optionsStream.flush();
830
831 // Get pass or pass pipeline from registry.
832 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
833 if (!info)
834 info = PassInfo::lookup(getPassName());
835 if (!info)
836 return emitDefiniteFailure()
837 << "unknown pass or pass pipeline: " << getPassName();
838
839 // Create pass manager and add the pass or pass pipeline.
841 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
842 emitError(msg);
843 return failure();
844 }))) {
845 return emitDefiniteFailure()
846 << "failed to add pass or pass pipeline to pipeline: "
847 << getPassName();
848 }
849
850 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
851 for (Operation *target : targets) {
852 // Make sure that this transform is not applied to itself. Modifying the
853 // transform IR while it is being interpreted is generally dangerous. Even
854 // more so when applying passes because they may perform a wide range of IR
855 // modifications.
856 DiagnosedSilenceableFailure payloadCheck =
858 if (!payloadCheck.succeeded())
859 return payloadCheck;
860
861 // Run the pass or pass pipeline on the current target operation.
862 if (failed(pm.run(target))) {
863 auto diag = emitSilenceableError() << "pass pipeline failed";
864 diag.attachNote(target->getLoc()) << "target op";
865 return diag;
866 }
867 }
868
869 // The applied pass will have directly modified the payload IR(s).
870 results.set(llvm::cast<OpResult>(getResult()), targets);
872}
873
875 OpAsmParser &parser, DictionaryAttr &options,
877 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
878 SmallVector<NamedAttribute> keyValuePairs;
879 size_t dynamicOptionsIdx = 0;
880
881 // Helper for allowing parsing of option values which can be of the form:
882 // - a normal attribute
883 // - an operand (which would be converted to an attr referring to the operand)
884 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
885 std::function<ParseResult(Attribute &)> parseValue =
886 [&](Attribute &valueAttr) -> ParseResult {
887 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
888 if (succeeded(parser.parseOptionalLSquare())) {
890
891 // Recursively parse the array's elements, which might be operands.
892 if (parser.parseCommaSeparatedList(
894 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
895 " in options dictionary") ||
896 parser.parseRSquare())
897 return failure(); // NB: Attempted parse should've output error message.
898
899 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
900
901 return success();
902 }
903
904 // Parse the value, which can be either an attribute or an operand.
905 OptionalParseResult parsedValueAttr =
906 parser.parseOptionalAttribute(valueAttr);
907 if (!parsedValueAttr.has_value()) {
909 ParseResult parsedOperand = parser.parseOperand(operand);
910 if (failed(parsedOperand))
911 return failure(); // NB: Attempted parse should've output error message.
912 // To make use of the operand, we need to store it in the options dict.
913 // As SSA-values cannot occur in attributes, what we do instead is store
914 // an attribute in its place that contains the index of the param-operand,
915 // so that an attr-value associated to the param can be resolved later on.
916 dynamicOptions.push_back(operand);
917 auto wrappedIndex = IntegerAttr::get(
918 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
919 valueAttr =
920 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
921 } else if (failed(parsedValueAttr.value())) {
922 return failure(); // NB: Attempted parse should have output error message.
923 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
924 return parser.emitError(parser.getCurrentLocation())
925 << "the param_operand attribute is a marker reserved for "
926 << "indicating a value will be passed via params and is only used "
927 << "in the generic print format";
928 }
929
930 return success();
931 };
932
933 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
934 // string and `value` looks like either an attribute or an operand-in-an-attr.
935 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
936 std::string key;
937 Attribute valueAttr;
938
939 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
940 return parser.emitError(parser.getCurrentLocation())
941 << "expected key to either be an identifier or a string";
942
943 if (failed(parser.parseEqual()))
944 return parser.emitError(parser.getCurrentLocation())
945 << "expected '=' after key in key-value pair";
946
947 if (failed(parseValue(valueAttr)))
948 return parser.emitError(parser.getCurrentLocation())
949 << "expected a valid attribute or operand as value associated "
950 << "to key '" << key << "'";
951
952 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
953
954 return success();
955 };
956
959 " in options dictionary"))
960 return failure(); // NB: Attempted parse should have output error message.
961
962 if (DictionaryAttr::findDuplicate(
963 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
964 .has_value())
965 return parser.emitError(parser.getCurrentLocation())
966 << "duplicate keys found in options dictionary";
967
968 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
969
970 return success();
971}
972
974 Operation *op,
975 DictionaryAttr options,
976 ValueRange dynamicOptions) {
977 if (options.empty())
978 return;
979
980 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
981 if (auto paramOperandAttr =
982 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
983 // Resolve index of param-operand to its actual SSA-value and print that.
984 printer.printOperand(
985 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
986 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
987 // This case is so that ArrayAttr-contained operands are pretty-printed.
988 printer << "[";
989 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
990 printer << "]";
991 } else {
992 printer.printAttribute(valueAttr);
993 }
994 };
995
996 printer << "{";
997 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
998 printer << namedAttribute.getName();
999 printer << " = ";
1000 printOptionValue(namedAttribute.getValue());
1001 });
1002 printer << "}";
1003}
1004
1005LogicalResult transform::ApplyRegisteredPassOp::verify() {
1006 // Check that there is a one-to-one correspondence between param operands
1007 // and references to dynamic options in the options dictionary.
1008
1009 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1010
1011 // Helper for option values to mark seen operands as having been seen (once).
1012 std::function<LogicalResult(Attribute)> checkOptionValue =
1013 [&](Attribute valueAttr) -> LogicalResult {
1014 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1015 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1016 if (dynamicOptionIdx < 0 ||
1017 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1018 return emitOpError()
1019 << "dynamic option index " << dynamicOptionIdx
1020 << " is out of bounds for the number of dynamic options: "
1021 << dynamicOptions.size();
1022 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1023 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1024 << " is already used in options";
1025 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1026 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1027 // Recurse into ArrayAttrs as they may contain references to operands.
1028 for (auto eltAttr : arrayAttr)
1029 if (failed(checkOptionValue(eltAttr)))
1030 return failure();
1031 }
1032 return success();
1033 };
1034
1035 for (NamedAttribute namedAttr : getOptions())
1036 if (failed(checkOptionValue(namedAttr.getValue())))
1037 return failure();
1038
1039 // All dynamicOptions-params seen in the dict will have been set to null.
1040 for (Value dynamicOption : dynamicOptions)
1041 if (dynamicOption)
1042 return emitOpError() << "a param operand does not have a corresponding "
1043 << "param_operand attr in the options dict";
1044
1045 return success();
1046}
1047
1048//===----------------------------------------------------------------------===//
1049// CastOp
1050//===----------------------------------------------------------------------===//
1051
1053transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1054 Operation *target, ApplyToEachResultList &results,
1056 results.push_back(target);
1058}
1059
1060void transform::CastOp::getEffects(
1062 onlyReadsPayload(effects);
1063 onlyReadsHandle(getInputMutable(), effects);
1064 producesHandle(getOperation()->getOpResults(), effects);
1065}
1066
1067bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1068 assert(inputs.size() == 1 && "expected one input");
1069 assert(outputs.size() == 1 && "expected one output");
1070 return llvm::all_of(
1071 std::initializer_list<Type>{inputs.front(), outputs.front()},
1072 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1073}
1074
1075//===----------------------------------------------------------------------===//
// Matching helpers shared by CollectMatchingOp and ForeachMatchOp
1077//===----------------------------------------------------------------------===//
1078
/// Applies matcher operations from the given `block` using
/// `blockArgumentMapping` to initialize block arguments, and updates `state`
/// accordingly. Matcher operations are applied in order. As soon as one does
/// not succeed, its diagnostic is returned to the caller unchanged, without
/// being printed or discarded. Callers typically log and skip silenceable
/// failures. A definite failure is also returned if an operation in the match
/// part does not implement MatchOpInterface. If all matchers in the block
/// succeed, populates `mappings` with the payload entities associated with the
/// block terminator operands. Note that `mappings` is cleared before that.
1089 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1092 assert(block.getParent() && "cannot match using a detached block");
1093 auto matchScope = state.make_region_scope(*block.getParent());
1094 if (failed(
1095 state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1097
1098 for (Operation &match : block.without_terminator()) {
1099 if (!isa<transform::MatchOpInterface>(match)) {
1100 return emitDefiniteFailure(match.getLoc())
1101 << "expected operations in the match part to "
1102 "implement MatchOpInterface";
1103 }
1105 state.applyTransform(cast<transform::TransformOpInterface>(match));
1106 if (diag.succeeded())
1107 continue;
1108
1109 return diag;
1110 }
1111
1112 // Remember the values mapped to the terminator operands so we can
1113 // forward them to the action.
1114 ValueRange yieldedValues = block.getTerminator()->getOperands();
1115 // Our contract with the caller is that the mappings will contain only the
1116 // newly mapped values, clear the rest.
1117 mappings.clear();
1118 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1120}
1121
1122/// Returns `true` if both types implement one of the interfaces provided as
1123/// template parameters.
1124template <typename... Tys>
1125static bool implementSameInterface(Type t1, Type t2) {
1126 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1127}
1128
1129/// Returns `true` if both types implement one of the transform dialect
1130/// interfaces.
1132 return implementSameInterface<transform::TransformHandleTypeInterface,
1133 transform::TransformParamTypeInterface,
1134 transform::TransformValueHandleTypeInterface>(
1135 t1, t2);
1136}
1137
1138//===----------------------------------------------------------------------===//
1139// CollectMatchingOp
1140//===----------------------------------------------------------------------===//
1141
1143transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1147 getOperation(), getMatcher());
1148 if (matcher.isExternal()) {
1149 return emitDefiniteFailure()
1150 << "unresolved external symbol " << getMatcher();
1151 }
1152
1154 rawResults.resize(getOperation()->getNumResults());
1155 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1156 for (Operation *root : state.getPayloadOps(getRoot())) {
1157 WalkResult walkResult = root->walk([&](Operation *op) {
1158 LDBG(DEBUG_TYPE_MATCHER, 1)
1159 << "matching "
1160 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1161 << " @" << op;
1162
1163 // Try matching.
1165 SmallVector<transform::MappedValue> inputMapping({op});
1167 matcher.getFunctionBody().front(),
1168 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1169 mappings);
1170 if (diag.isDefiniteFailure())
1171 return WalkResult::interrupt();
1172 if (diag.isSilenceableFailure()) {
1173 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1174 << " failed: " << diag.getMessage();
1175 return WalkResult::advance();
1176 }
1177
1178 // If succeeded, collect results.
1179 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1180 if (mapping.size() != 1) {
1181 maybeFailure.emplace(emitSilenceableError()
1182 << "result #" << i << ", associated with "
1183 << mapping.size()
1184 << " payload objects, expected 1");
1185 return WalkResult::interrupt();
1186 }
1187 rawResults[i].push_back(mapping[0]);
1188 }
1189 return WalkResult::advance();
1190 });
1191 if (walkResult.wasInterrupted())
1192 return std::move(*maybeFailure);
1193 assert(!maybeFailure && "failure set but the walk was not interrupted");
1194
1195 for (auto &&[opResult, rawResult] :
1196 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1197 results.setMappedValues(opResult, rawResult);
1198 }
1199 }
1201}
1202
1203void transform::CollectMatchingOp::getEffects(
1205 onlyReadsHandle(getRootMutable(), effects);
1206 producesHandle(getOperation()->getOpResults(), effects);
1207 onlyReadsPayload(effects);
1208}
1209
1210LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1211 SymbolTableCollection &symbolTable) {
1212 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1213 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1214 if (!matcherSymbol ||
1215 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1216 return emitError() << "unresolved matcher symbol " << getMatcher();
1217
1218 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1219 if (argumentTypes.size() != 1 ||
1220 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1221 return emitError()
1222 << "expected the matcher to take one operation handle argument";
1223 }
1224 if (!matcherSymbol.getArgAttr(
1225 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1226 return emitError() << "expected the matcher argument to be marked readonly";
1227 }
1228
1229 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1230 if (resultTypes.size() != getOperation()->getNumResults()) {
1231 return emitError()
1232 << "expected the matcher to yield as many values as op has results ("
1233 << getOperation()->getNumResults() << "), got "
1234 << resultTypes.size();
1235 }
1236
1237 for (auto &&[i, matcherType, resultType] :
1238 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1239 if (implementSameTransformInterface(matcherType, resultType))
1240 continue;
1241
1242 return emitError()
1243 << "mismatching type interfaces for matcher result and op result #"
1244 << i;
1245 }
1246
1247 return success();
1248}
1249
1250//===----------------------------------------------------------------------===//
1251// ForeachMatchOp
1252//===----------------------------------------------------------------------===//
1253
1254// This is fine because nothing is actually consumed by this op.
1255bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1256
1258transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1262 matchActionPairs;
1263 matchActionPairs.reserve(getMatchers().size());
1264 SymbolTableCollection symbolTable;
1265 for (auto &&[matcher, action] :
1266 llvm::zip_equal(getMatchers(), getActions())) {
1267 auto matcherSymbol =
1268 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1269 getOperation(), cast<SymbolRefAttr>(matcher));
1270 auto actionSymbol =
1271 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1272 getOperation(), cast<SymbolRefAttr>(action));
1273 assert(matcherSymbol && actionSymbol &&
1274 "unresolved symbols not caught by the verifier");
1275
1276 if (matcherSymbol.isExternal())
1277 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1278 if (actionSymbol.isExternal())
1279 return emitDefiniteFailure() << "unresolved external symbol " << action;
1280
1281 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1282 }
1283
1284 DiagnosedSilenceableFailure overallDiag =
1286
1287 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1288 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1289 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1290 // Explicitly add the mapping for the first block argument (the op being
1291 // matched).
1292 matchInputMapping.emplace_back();
1294 getForwardedInputs(), state);
1295 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1296 actionResultMapping.resize(getForwardedOutputs().size());
1297
1298 for (Operation *root : state.getPayloadOps(getRoot())) {
1299 WalkResult walkResult = root->walk([&](Operation *op) {
1300 // If getRestrictRoot is not present, skip over the root op itself so we
1301 // don't invalidate it.
1302 if (!getRestrictRoot() && op == root)
1303 return WalkResult::advance();
1304
1305 LDBG(DEBUG_TYPE_MATCHER, 1)
1306 << "matching "
1307 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1308 << " @" << op;
1309
1310 firstMatchArgument.clear();
1311 firstMatchArgument.push_back(op);
1312
1313 // Try all the match/action pairs until the first successful match.
1314 for (auto [matcher, action] : matchActionPairs) {
1316 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1317 state, matchOutputMapping);
1318 if (diag.isDefiniteFailure())
1319 return WalkResult::interrupt();
1320 if (diag.isSilenceableFailure()) {
1321 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1322 << " failed: " << diag.getMessage();
1323 continue;
1324 }
1325
1326 auto scope = state.make_region_scope(action.getFunctionBody());
1327 if (failed(state.mapBlockArguments(
1328 action.getFunctionBody().front().getArguments(),
1329 matchOutputMapping))) {
1330 return WalkResult::interrupt();
1331 }
1332
1333 for (Operation &transform :
1334 action.getFunctionBody().front().without_terminator()) {
1336 state.applyTransform(cast<TransformOpInterface>(transform));
1337 if (result.isDefiniteFailure())
1338 return WalkResult::interrupt();
1339 if (result.isSilenceableFailure()) {
1340 if (overallDiag.succeeded()) {
1341 overallDiag = emitSilenceableError() << "actions failed";
1342 }
1343 overallDiag.attachNote(action->getLoc())
1344 << "failed action: " << result.getMessage();
1345 overallDiag.attachNote(op->getLoc())
1346 << "when applied to this matching payload";
1347 (void)result.silence();
1348 continue;
1349 }
1350 }
1351 if (failed(detail::appendValueMappings(
1352 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1353 action.getFunctionBody().front().getTerminator()->getOperands(),
1354 state, getFlattenResults()))) {
1356 << "action @" << action.getName()
1357 << " has results associated with multiple payload entities, "
1358 "but flattening was not requested";
1359 return WalkResult::interrupt();
1360 }
1361 break;
1362 }
1363 return WalkResult::advance();
1364 });
1365 if (walkResult.wasInterrupted())
1367 }
1368
1369 // The root operation should not have been affected, so we can just reassign
1370 // the payload to the result. Note that we need to consume the root handle to
1371 // make sure any handles to operations inside, that could have been affected
1372 // by actions, are invalidated.
1373 results.set(llvm::cast<OpResult>(getUpdated()),
1374 state.getPayloadOps(getRoot()));
1375 for (auto &&[result, mapping] :
1376 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1377 results.setMappedValues(result, mapping);
1378 }
1379 return overallDiag;
1380}
1381
1382void transform::ForeachMatchOp::getAsmResultNames(
1383 OpAsmSetValueNameFn setNameFn) {
1384 setNameFn(getUpdated(), "updated_root");
1385 for (Value v : getForwardedOutputs()) {
1386 setNameFn(v, "yielded");
1387 }
1388}
1389
1390void transform::ForeachMatchOp::getEffects(
1392 // Bail if invalid.
1393 if (getOperation()->getNumOperands() < 1 ||
1394 getOperation()->getNumResults() < 1) {
1395 return modifiesPayload(effects);
1396 }
1397
1398 consumesHandle(getRootMutable(), effects);
1399 onlyReadsHandle(getForwardedInputsMutable(), effects);
1400 producesHandle(getOperation()->getOpResults(), effects);
1401 modifiesPayload(effects);
1402}
1403
1404/// Parses the comma-separated list of symbol reference pairs of the format
1405/// `@matcher -> @action`.
1406static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1408 ArrayAttr &actions) {
1409 StringAttr matcher;
1410 StringAttr action;
1411 SmallVector<Attribute> matcherList;
1412 SmallVector<Attribute> actionList;
1413 do {
1414 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1415 parser.parseSymbolName(action)) {
1416 return failure();
1417 }
1418 matcherList.push_back(SymbolRefAttr::get(matcher));
1419 actionList.push_back(SymbolRefAttr::get(action));
1420 } while (parser.parseOptionalComma().succeeded());
1421
1422 matchers = parser.getBuilder().getArrayAttr(matcherList);
1423 actions = parser.getBuilder().getArrayAttr(actionList);
1424 return success();
1425}
1426
1427/// Prints the comma-separated list of symbol reference pairs of the format
1428/// `@matcher -> @action`.
1430 ArrayAttr matchers, ArrayAttr actions) {
1431 printer.increaseIndent();
1432 printer.increaseIndent();
1433 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1434 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1435 printer.printNewline();
1436 printer << cast<SymbolRefAttr>(matcher) << " -> "
1437 << cast<SymbolRefAttr>(action);
1438 if (idx != matchers.size() - 1)
1439 printer << ", ";
1440 }
1441 printer.decreaseIndent();
1442 printer.decreaseIndent();
1443}
1444
1445LogicalResult transform::ForeachMatchOp::verify() {
1446 if (getMatchers().size() != getActions().size())
1447 return emitOpError() << "expected the same number of matchers and actions";
1448 if (getMatchers().empty())
1449 return emitOpError() << "expected at least one match/action pair";
1450
1452 for (Attribute name : getMatchers()) {
1453 if (matcherNames.insert(name).second)
1454 continue;
1455 emitWarning() << "matcher " << name
1456 << " is used more than once, only the first match will apply";
1457 }
1458
1459 return success();
1460}
1461
1462/// Checks that the attributes of the function-like operation have correct
1463/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1464/// annotations being present even if they can be inferred from the body.
1466verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1467 bool alsoVerifyInternal = false) {
1468 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1469 llvm::SmallDenseSet<unsigned> consumedArguments;
1470 if (!op.isExternal()) {
1471 transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1472 consumedArguments);
1473 }
1474 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1475 bool isConsumed =
1476 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1477 nullptr;
1478 bool isReadOnly =
1479 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1480 nullptr;
1481 if (isConsumed && isReadOnly) {
1482 return transformOp.emitSilenceableError()
1483 << "argument #" << i << " cannot be both readonly and consumed";
1484 }
1485 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1486 return transformOp.emitSilenceableError()
1487 << "must provide consumed/readonly status for arguments of "
1488 "external or called ops";
1489 }
1490 if (op.isExternal())
1491 continue;
1492
1493 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1494 return transformOp.emitSilenceableError()
1495 << "argument #" << i
1496 << " is consumed in the body but is not marked as such";
1497 }
1498 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1499 // Cannot use op.emitWarning() here as it would attempt to verify the op
1500 // before printing, resulting in infinite recursion.
1501 emitWarning(op->getLoc())
1502 << "op argument #" << i
1503 << " is not consumed in the body but is marked as consumed";
1504 }
1505 }
1507}
1508
1509LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1510 SymbolTableCollection &symbolTable) {
1511 assert(getMatchers().size() == getActions().size());
1512 auto consumedAttr =
1513 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1514 for (auto &&[matcher, action] :
1515 llvm::zip_equal(getMatchers(), getActions())) {
1516 // Presence and typing.
1517 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1518 symbolTable.lookupNearestSymbolFrom(getOperation(),
1519 cast<SymbolRefAttr>(matcher)));
1520 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1521 symbolTable.lookupNearestSymbolFrom(getOperation(),
1522 cast<SymbolRefAttr>(action)));
1523 if (!matcherSymbol ||
1524 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1525 return emitError() << "unresolved matcher symbol " << matcher;
1526 if (!actionSymbol ||
1527 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1528 return emitError() << "unresolved action symbol " << action;
1529
1531 /*emitWarnings=*/false,
1532 /*alsoVerifyInternal=*/true)
1533 .checkAndReport())) {
1534 return failure();
1535 }
1537 /*emitWarnings=*/false,
1538 /*alsoVerifyInternal=*/true)
1539 .checkAndReport())) {
1540 return failure();
1541 }
1542
1543 // Input -> matcher forwarding.
1544 TypeRange operandTypes = getOperandTypes();
1545 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1546 if (operandTypes.size() != matcherArguments.size()) {
1548 emitError() << "the number of operands (" << operandTypes.size()
1549 << ") doesn't match the number of matcher arguments ("
1550 << matcherArguments.size() << ") for " << matcher;
1551 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1552 return diag;
1553 }
1554 for (auto &&[i, operand, argument] :
1555 llvm::enumerate(operandTypes, matcherArguments)) {
1556 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1558 emitOpError()
1559 << "does not expect matcher symbol to consume its operand #" << i;
1560 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1561 return diag;
1562 }
1563
1564 if (implementSameTransformInterface(operand, argument))
1565 continue;
1566
1568 emitError()
1569 << "mismatching type interfaces for operand and matcher argument #"
1570 << i << " of matcher " << matcher;
1571 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1572 return diag;
1573 }
1574
1575 // Matcher -> action forwarding.
1576 TypeRange matcherResults = matcherSymbol.getResultTypes();
1577 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1578 if (matcherResults.size() != actionArguments.size()) {
1579 return emitError() << "mismatching number of matcher results and "
1580 "action arguments between "
1581 << matcher << " (" << matcherResults.size() << ") and "
1582 << action << " (" << actionArguments.size() << ")";
1583 }
1584 for (auto &&[i, matcherType, actionType] :
1585 llvm::enumerate(matcherResults, actionArguments)) {
1586 if (implementSameTransformInterface(matcherType, actionType))
1587 continue;
1588
1589 return emitError() << "mismatching type interfaces for matcher result "
1590 "and action argument #"
1591 << i << "of matcher " << matcher << " and action "
1592 << action;
1593 }
1594
1595 // Action -> result forwarding.
1596 TypeRange actionResults = actionSymbol.getResultTypes();
1597 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1598 if (actionResults.size() != resultTypes.size()) {
1600 emitError() << "the number of action results ("
1601 << actionResults.size() << ") for " << action
1602 << " doesn't match the number of extra op results ("
1603 << resultTypes.size() << ")";
1604 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1605 return diag;
1606 }
1607 for (auto &&[i, resultType, actionType] :
1608 llvm::enumerate(resultTypes, actionResults)) {
1609 if (implementSameTransformInterface(resultType, actionType))
1610 continue;
1611
1613 emitError() << "mismatching type interfaces for action result #" << i
1614 << " of action " << action << " and op result";
1615 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1616 return diag;
1617 }
1618 }
1619 return success();
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// ForeachOp
1624//===----------------------------------------------------------------------===//
1625
1627transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1630 // We store the payloads before executing the body as ops may be removed from
1631 // the mapping by the TrackingRewriter while iteration is in progress.
1633 detail::prepareValueMappings(payloads, getTargets(), state);
1634 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1635 bool withZipShortest = getWithZipShortest();
1636
1637 // In case of `zip_shortest`, set the number of iterations to the
1638 // smallest payload in the targets.
1639 if (withZipShortest) {
1640 numIterations =
1641 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1642 const SmallVector<MappedValue> &b) {
1643 return a.size() < b.size();
1644 })->size();
1645
1646 for (auto &payload : payloads)
1647 payload.resize(numIterations);
1648 }
1649
1650 // As we will be "zipping" over them, check all payloads have the same size.
1651 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1652 // when true.
1653 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1654 argIdx++) {
1655 if (payloads[argIdx].size() != numIterations) {
1656 return emitSilenceableError()
1657 << "prior targets' payload size (" << numIterations
1658 << ") differs from payload size (" << payloads[argIdx].size()
1659 << ") of target " << getTargets()[argIdx];
1660 }
1661 }
1662
1663 // Start iterating, indexing into payloads to obtain the right arguments to
1664 // call the body with - each slice of payloads at the same argument index
1665 // corresponding to a tuple to use as the body's block arguments.
1666 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1667 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1668 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1669 auto scope = state.make_region_scope(getBody());
1670 // Set up arguments to the region's block.
1671 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1672 MappedValue argument = payloads[argIdx][iterIdx];
1673 // Note that each blockArg's handle gets associated with just a single
1674 // element from the corresponding target's payload.
1675 if (failed(state.mapBlockArgument(blockArg, {argument})))
1677 }
1678
1679 // Execute loop body.
1680 for (Operation &transform : getBody().front().without_terminator()) {
1682 llvm::cast<transform::TransformOpInterface>(transform));
1683 if (!result.succeeded())
1684 return result;
1685 }
1686
1687 // Append yielded payloads to corresponding results from prior iterations.
1688 OperandRange yieldOperands = getYieldOp().getOperands();
1689 for (auto &&[result, yieldOperand, resTuple] :
1690 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1691 // NB: each iteration we add any number of ops/vals/params to a result.
1692 if (isa<TransformHandleTypeInterface>(result.getType()))
1693 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1694 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1695 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1696 else if (isa<TransformParamTypeInterface>(result.getType()))
1697 llvm::append_range(resTuple, state.getParams(yieldOperand));
1698 else
1699 assert(false && "unhandled handle type");
1700 }
1701
1702 // Associate the accumulated result payloads to the op's actual results.
1703 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1704 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1705
1707}
1708
1709void transform::ForeachOp::getEffects(
1711 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1712 // arity errors, this method might get called before/in absence of `verify()`.
1713 for (auto &&[target, blockArg] :
1714 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1715 BlockArgument blockArgument = blockArg;
1716 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1717 return isHandleConsumed(blockArgument,
1718 cast<TransformOpInterface>(&op));
1719 })) {
1720 consumesHandle(target, effects);
1721 } else {
1722 onlyReadsHandle(target, effects);
1723 }
1724 }
1725
1726 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1727 return doesModifyPayload(cast<TransformOpInterface>(&op));
1728 })) {
1729 modifiesPayload(effects);
1730 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1731 return doesReadPayload(cast<TransformOpInterface>(&op));
1732 })) {
1733 onlyReadsPayload(effects);
1734 }
1735
1736 producesHandle(getOperation()->getOpResults(), effects);
1737}
1738
1739void transform::ForeachOp::getSuccessorRegions(
1741 Region *bodyRegion = &getBody();
1742 if (point.isParent()) {
1743 regions.emplace_back(bodyRegion);
1744 return;
1745 }
1746
1747 // Branch back to the region or the parent.
1748 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
1749 &getBody() &&
1750 "unexpected region index");
1751 regions.emplace_back(bodyRegion);
1752 regions.push_back(RegionSuccessor::parent());
1753}
1754
1755ValueRange transform::ForeachOp::getSuccessorInputs(RegionSuccessor successor) {
1756 return successor.isParent() ? ValueRange(getResults())
1757 : ValueRange(getBody().getArguments());
1758}
1759
1761transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1762 // Each block argument handle is mapped to a subset (one op to be precise)
1763 // of the payload of the corresponding `targets` operand of ForeachOp.
1764 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1765 return getOperation()->getOperands();
1766}
1767
1768transform::YieldOp transform::ForeachOp::getYieldOp() {
1769 return cast<transform::YieldOp>(getBody().front().getTerminator());
1770}
1771
1772LogicalResult transform::ForeachOp::verify() {
1773 for (auto [targetOpt, bodyArgOpt] :
1774 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1775 if (!targetOpt || !bodyArgOpt)
1776 return emitOpError() << "expects the same number of targets as the body "
1777 "has block arguments";
1778 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1779 return emitOpError(
1780 "expects co-indexed targets and the body's "
1781 "block arguments to have the same op/value/param type");
1782 }
1783
1784 for (auto [resultOpt, yieldOperandOpt] :
1785 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1786 if (!resultOpt || !yieldOperandOpt)
1787 return emitOpError() << "expects the same number of results as the "
1788 "yield terminator has operands";
1789 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1790 return emitOpError("expects co-indexed results and yield "
1791 "operands to have the same op/value/param type");
1792 }
1793
1794 return success();
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// GetParentOp
1799//===----------------------------------------------------------------------===//
1800
// For each payload op of `target`, climbs the ancestor chain getNthParent()
// times. Each step moves to the closest ancestor that passes both the optional
// isolated-from-above filter (getIsolatedFromAbove) and the optional op-name
// filter (getOpName). If no such ancestor exists, the result is mapped to the
// parents collected so far when getAllowEmptyResults() is set; otherwise a
// silenceable error with a note on the target op is returned. With
// getDeduplicate(), a parent is added to the result only once.
1802transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1806 DenseSet<Operation *> resultSet;
1807 for (Operation *target : state.getPayloadOps(getTarget())) {
1808 Operation *parent = target;
1809 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1810 parent = parent->getParentOp();
// Keep climbing until an ancestor satisfies both filters or the chain ends.
1811 while (parent) {
1812 bool checkIsolatedFromAbove =
1813 !getIsolatedFromAbove() ||
1815 bool checkOpName = !getOpName().has_value() ||
1816 parent->getName().getStringRef() == *getOpName();
1817 if (checkIsolatedFromAbove && checkOpName)
1818 break;
1819 parent = parent->getParentOp();
1820 }
1821 if (!parent) {
1822 if (getAllowEmptyResults()) {
1823 results.set(llvm::cast<OpResult>(getResult()), parents);
1825 }
1827 emitSilenceableError()
1828 << "could not find a parent op that matches all requirements";
1829 diag.attachNote(target->getLoc()) << "target op";
1830 return diag;
1831 }
1832 }
// `resultSet` remembers parents already emitted when deduplicating.
1833 if (getDeduplicate()) {
1834 if (resultSet.insert(parent).second)
1835 parents.push_back(parent);
1836 } else {
1837 parents.push_back(parent);
1838 }
1839 }
1840 results.set(llvm::cast<OpResult>(getResult()), parents);
1842}
1843
1844//===----------------------------------------------------------------------===//
1845// GetConsumersOfResult
1846//===----------------------------------------------------------------------===//
1847
// Maps the result to the users of result #getResultNumber() of the single
// payload op associated with `target`. An empty `target` maps the result to an
// empty list. More than one payload op, or a result number past the op's
// results, is a definite failure.
// NOTE(review): getUsers() walks uses, so an op consuming the value through
// several operands likely appears more than once — confirm this is intended.
1849transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1852 int64_t resultNumber = getResultNumber();
1853 auto payloadOps = state.getPayloadOps(getTarget());
1854 if (std::empty(payloadOps)) {
1855 results.set(cast<OpResult>(getResult()), {});
1857 }
1858 if (!llvm::hasSingleElement(payloadOps))
1859 return emitDefiniteFailure()
1860 << "handle must be mapped to exactly one payload op";
1861
1862 Operation *target = *payloadOps.begin();
1863 if (target->getNumResults() <= resultNumber)
1864 return emitDefiniteFailure() << "result number overflow";
1865 results.set(llvm::cast<OpResult>(getResult()),
1866 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1868}
1869
1870//===----------------------------------------------------------------------===//
1871// GetDefiningOp
1872//===----------------------------------------------------------------------===//
1873
// Maps the result to the defining op of each payload value of `target`, in
// order. A payload value that is a block argument has no defining op; it
// produces a silenceable error with a note at the value's location.
1875transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1878 SmallVector<Operation *> definingOps;
1879 for (Value v : state.getPayloadValues(getTarget())) {
1880 if (llvm::isa<BlockArgument>(v)) {
1882 emitSilenceableError() << "cannot get defining op of block argument";
1883 diag.attachNote(v.getLoc()) << "target value";
1884 return diag;
1885 }
1886 definingOps.push_back(v.getDefiningOp());
1887 }
1888 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1890}
1891
1892//===----------------------------------------------------------------------===//
1893// GetProducerOfOperand
1894//===----------------------------------------------------------------------===//
1895
// Maps the result to the op defining operand #getOperandNumber() of each
// payload op of `target`. A silenceable error is returned if the operand index
// is out of range or the operand has no defining op (e.g. a block argument).
1897transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1900 int64_t operandNumber = getOperandNumber();
1901 SmallVector<Operation *> producers;
1902 for (Operation *target : state.getPayloadOps(getTarget())) {
// Out-of-range indices and non-op-defined operands both yield nullptr.
1903 Operation *producer =
1904 target->getNumOperands() <= operandNumber
1905 ? nullptr
1906 : target->getOperand(operandNumber).getDefiningOp();
1907 if (!producer) {
1909 emitSilenceableError()
1910 << "could not find a producer for operand number: " << operandNumber
1911 << " of " << *target;
1912 diag.attachNote(target->getLoc()) << "target op";
1913 return diag;
1914 }
1915 producers.push_back(producer);
1916 }
1917 results.set(llvm::cast<OpResult>(getResult()), producers);
1919}
1920
1921//===----------------------------------------------------------------------===//
1922// GetOperandOp
1923//===----------------------------------------------------------------------===//
1924
// For each payload op of `target`, expands the position specification
// (getIsAll / getIsInverted / getRawPositionList) against the op's operand
// count. The selected operands are appended in order and the result is mapped
// to all of them. A silenceable expansion failure gets a note on the payload op.
// NOTE(review): only the silenceable case is returned; if the expansion helper
// can also report a definite failure, it would be dropped here — confirm it
// cannot.
1926transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1929 SmallVector<Value> operands;
1930 for (Operation *target : state.getPayloadOps(getTarget())) {
1931 SmallVector<int64_t> operandPositions;
1933 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1934 target->getNumOperands(), operandPositions);
1935 if (diag.isSilenceableFailure()) {
1936 diag.attachNote(target->getLoc())
1937 << "while considering positions of this payload operation";
1938 return diag;
1939 }
1940 llvm::append_range(operands,
1941 llvm::map_range(operandPositions, [&](int64_t pos) {
1942 return target->getOperand(pos);
1943 }));
1944 }
1945 results.setValues(cast<OpResult>(getResult()), operands);
1947}
1948
1949LogicalResult transform::GetOperandOp::verify() {
1950 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1951 getIsInverted(), getIsAll());
1952}
1953
1954//===----------------------------------------------------------------------===//
1955// GetResultOp
1956//===----------------------------------------------------------------------===//
1957
// For each payload op of `target`, expands the position specification
// (getIsAll / getIsInverted / getRawPositionList) against the op's result
// count. The selected results are appended in order and the result handle is
// mapped to all of them. A silenceable expansion failure gets a note on the
// payload op.
1959transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1962 SmallVector<Value> opResults;
1963 for (Operation *target : state.getPayloadOps(getTarget())) {
1964 SmallVector<int64_t> resultPositions;
1966 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1967 target->getNumResults(), resultPositions);
1968 if (diag.isSilenceableFailure()) {
1969 diag.attachNote(target->getLoc())
1970 << "while considering positions of this payload operation";
1971 return diag;
1972 }
1973 llvm::append_range(opResults,
1974 llvm::map_range(resultPositions, [&](int64_t pos) {
1975 return target->getResult(pos);
1976 }));
1977 }
1978 results.setValues(cast<OpResult>(getResult()), opResults);
1980}
1981
1982LogicalResult transform::GetResultOp::verify() {
1983 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1984 getIsInverted(), getIsAll());
1985}
1986
1987//===----------------------------------------------------------------------===//
1988// GetTypeOp
1989//===----------------------------------------------------------------------===//
1990
// Only reads the `value` handle, produces the result handle, and only reads
// the payload.
1991void transform::GetTypeOp::getEffects(
1993 onlyReadsHandle(getValueMutable(), effects);
1994 producesHandle(getOperation()->getOpResults(), effects);
1995 onlyReadsPayload(effects);
1996}
1997
// Produces one TypeAttr param per payload value of `value`. The param is the
// value's type, or its element type when getElemental() is set and the type is
// a ShapedType; non-shaped types are used as-is.
1999transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
2003 for (Value value : state.getPayloadValues(getValue())) {
2004 Type type = value.getType();
2005 if (getElemental()) {
2006 if (auto shaped = dyn_cast<ShapedType>(type)) {
2007 type = shaped.getElementType();
2008 }
2009 }
2010 params.push_back(TypeAttr::get(type));
2011 }
2012 results.setParams(cast<OpResult>(getResult()), params);
2014}
2015
2016//===----------------------------------------------------------------------===//
2017// IncludeOp
2018//===----------------------------------------------------------------------===//
2019
2020/// Applies the transform ops contained in `block`. Maps `results` to the same
2021/// values as the operands of the block terminator.
/// Definite failures abort immediately. Silenceable failures abort (after
/// propagating empty results) when `mode` is Propagate; otherwise they are
/// silenced and the remaining ops still run.
2023applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2025 transform::TransformResults &results) {
2026 // Apply the sequenced ops one by one.
2027 for (Operation &transform : block.without_terminator()) {
2029 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2030 if (result.isDefiniteFailure())
2031 return result;
2032
2033 if (result.isSilenceableFailure()) {
2034 if (mode == transform::FailurePropagationMode::Propagate) {
2035 // Propagate empty results in case of early exit.
2036 forwardEmptyOperands(&block, state, results);
2037 return result;
2038 }
2039 (void)result.silence();
2040 }
2041 }
2042
2043 // Forward the operation mapping for values yielded from the sequence to the
2044 // values produced by the sequence op.
2045 transform::detail::forwardTerminatorOperands(&block, state, results);
2047}
2048
// Runs the referenced named sequence. It binds this op's operand payloads to
// the callee's entry-block arguments, applies the callee body with this op's
// failure propagation mode, and then maps this op's results to the payloads of
// the callee terminator's operands. An external (body-less) callee is a
// definite failure.
2050transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2054 getOperation(), getTarget());
2055 assert(callee && "unverified reference to unknown symbol");
2056
2057 if (callee.isExternal())
2058 return emitDefiniteFailure() << "unresolved external named sequence";
2059
2060 // Map operands to block arguments.
2062 detail::prepareValueMappings(mappings, getOperands(), state);
2063 auto scope = state.make_region_scope(callee.getBody());
2064 for (auto &&[arg, map] :
2065 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2066 if (failed(state.mapBlockArgument(arg, map)))
2068 }
2069
2071 callee.getBody().front(), getFailurePropagationMode(), state, results);
2072
// Stop on any (definite or propagated silenceable) failure.
2073 if (!result.succeeded())
2074 return result;
2075
// Reuse `mappings` to gather the payloads yielded by the callee terminator.
2076 mappings.clear();
2077 detail::prepareValueMappings(
2078 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
// NOTE: the loop-scoped binding `result` shadows the outer `result`; the
// `return result;` below refers to the outer DiagnosedSilenceableFailure.
2079 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2080 results.setMappedValues(result, mapping);
2081 return result;
2082}
2083
// Forward declaration; the definition appears later in this file.
2085verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2086
// Always reports payload modification and production of all results. Operand
// handle effects are taken from the callee's argument attributes: consumed
// arguments consume the handle, and read-only arguments only read it. If the
// callee cannot be resolved yet, every operand is conservatively marked as
// only read.
2087void transform::IncludeOp::getEffects(
2089 // Always mark as modifying the payload.
2090 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2091 // only read, the payload may still be modified, so we currently stay on the
2092 // conservative side and always indicate modification. This may prevent some
2093 // code reordering.
2094 modifiesPayload(effects);
2095
2096 // Results are always produced.
2097 producesHandle(getOperation()->getOpResults(), effects);
2098
2099 // Adds default effects to operands and results. This will be added if
2100 // preconditions fail so the trait verifier doesn't complain about missing
2101 // effects and the real precondition failure is reported later on.
2102 auto defaultEffects = [&] {
2103 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2104 };
2105
2106 // Bail if the callee is unknown. This may run as part of the verification
2107 // process before we verified the validity of the callee or of this op.
2108 auto target =
2109 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2110 if (!target)
2111 return defaultEffects();
2113 getOperation(), getTarget());
2114 if (!callee)
2115 return defaultEffects();
2116
// Operands whose callee argument has neither attribute get no handle effect
// here.
// NOTE(review): this indexes callee arguments by this op's operand count
// without checking that the counts match, yet (per the comment above) it can
// run before verifySymbolUses reports a count mismatch. Confirm getArgAttr
// tolerates out-of-range indices.
2117 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2118 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2119 consumesHandle(getOperation()->getOpOperand(i), effects);
2120 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2121 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2122 }
2123}
2124
// Checks the symbol reference against the callee:
// - `target` must name a NamedSequenceOp.
// - Operand counts and types must match the callee inputs exactly.
// - Result counts must match, and each result type must implement the same
//   transform interface as the corresponding callee result.
// - Finally, the callee's function-like annotations are checked without
//   warnings.
// NOTE(review): the operand/result count errors use emitError while the other
// errors use emitOpError, so their message prefixes differ.
2125LogicalResult
2126transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2127 // Access through indirection and do additional checking because this may be
2128 // running before the main op verifier.
2129 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2130 if (!targetAttr)
2131 return emitOpError() << "expects a 'target' symbol reference attribute";
2132
2133 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2134 *this, targetAttr);
2135 if (!target)
2136 return emitOpError() << "does not reference a named transform sequence";
2137
2138 FunctionType fnType = target.getFunctionType();
2139 if (fnType.getNumInputs() != getNumOperands())
2140 return emitError("incorrect number of operands for callee");
2141
2142 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2143 if (getOperand(i).getType() != fnType.getInput(i)) {
2144 return emitOpError("operand type mismatch: expected operand type ")
2145 << fnType.getInput(i) << ", but provided "
2146 << getOperand(i).getType() << " for operand number " << i;
2147 }
2148 }
2149
2150 if (fnType.getNumResults() != getNumResults())
2151 return emitError("incorrect number of results for callee");
2152
// Results are compared more loosely than operands: same interface suffices.
2153 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2154 Type resultType = getResult(i).getType();
2155 Type funcType = fnType.getResult(i);
2156 if (!implementSameTransformInterface(resultType, funcType)) {
2157 return emitOpError() << "type of result #" << i
2158 << " must implement the same transform dialect "
2159 "interface as the corresponding callee result";
2160 }
2161 }
2162
2164 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2165 /*alsoVerifyInternal=*/true)
2166 .checkAndReport();
2167}
2168
2169//===----------------------------------------------------------------------===//
2170// MatchOperationEmptyOp
2171//===----------------------------------------------------------------------===//
2172
// Matches when `maybeCurrent` holds no operation. Otherwise it returns a
// silenceable "operation is not empty" error. Both outcomes are logged under
// the matcher debug type.
2173DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2174 ::std::optional<::mlir::Operation *> maybeCurrent,
2176 if (!maybeCurrent.has_value()) {
2177 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2179 }
2180 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2181 return emitSilenceableError() << "operation is not empty";
2182}
2183
2184//===----------------------------------------------------------------------===//
2185// MatchOperationNameOp
2186//===----------------------------------------------------------------------===//
2187
// Matches when the name of `current` equals one of the accepted names in
// getOpNames(). Otherwise it returns a silenceable "wrong operation name"
// error.
2188DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2189 Operation *current, transform::TransformResults &results,
2191 StringRef currentOpName = current->getName().getStringRef();
2192 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2193 if (acceptedAttr.getValue() == currentOpName)
2195 }
2196 return emitSilenceableError() << "wrong operation name";
2197}
2198
2199//===----------------------------------------------------------------------===//
2200// MatchParamCmpIOp
2201//===----------------------------------------------------------------------===//
2202
// Compares each integer in `param` with the co-indexed integer in `reference`
// using getPredicate(); the ordering predicates use signed comparison.
// - Differing payload lengths, or a failed comparison, produce a silenceable
//   error.
// - Non-integer params, or integer attributes of different types, are definite
//   failures.
2204transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
// Renders an APInt as a signed decimal string for diagnostics.
2207 auto signedAPIntAsString = [&](const APInt &value) {
2208 std::string str;
2209 llvm::raw_string_ostream os(str);
2210 value.print(os, /*isSigned=*/true);
2211 return str;
2212 };
2213
2214 ArrayRef<Attribute> params = state.getParams(getParam());
2215 ArrayRef<Attribute> references = state.getParams(getReference());
2216
2217 if (params.size() != references.size()) {
2218 return emitSilenceableError()
2219 << "parameters have different payload lengths (" << params.size()
2220 << " vs " << references.size() << ")";
2221 }
2222
2223 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2224 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2225 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2226 if (!intAttr || !refAttr) {
2227 return emitDefiniteFailure()
2228 << "non-integer parameter value not expected";
2229 }
2230 if (intAttr.getType() != refAttr.getType()) {
2231 return emitDefiniteFailure()
2232 << "mismatching integer attribute types in parameter #" << i;
2233 }
2234 APInt value = intAttr.getValue();
2235 APInt refValue = refAttr.getValue();
2236
2237 // TODO: this copy will not be necessary in C++20.
// (Structured bindings such as `i` cannot be captured by lambdas before C++20.)
2238 int64_t position = i;
2239 auto reportError = [&](StringRef direction) {
2241 emitSilenceableError() << "expected parameter to be " << direction
2242 << " " << signedAPIntAsString(refValue)
2243 << ", got " << signedAPIntAsString(value);
2244 diag.attachNote(getParam().getLoc())
2245 << "value # " << position
2246 << " associated with the parameter defined here";
2247 return diag;
2248 };
2249
// Each case breaks on success and otherwise returns a silenceable error.
2250 switch (getPredicate()) {
2251 case MatchCmpIPredicate::eq:
2252 if (value.eq(refValue))
2253 break;
2254 return reportError("equal to");
2255 case MatchCmpIPredicate::ne:
2256 if (value.ne(refValue))
2257 break;
2258 return reportError("not equal to");
2259 case MatchCmpIPredicate::lt:
2260 if (value.slt(refValue))
2261 break;
2262 return reportError("less than");
2263 case MatchCmpIPredicate::le:
2264 if (value.sle(refValue))
2265 break;
2266 return reportError("less than or equal to");
2267 case MatchCmpIPredicate::gt:
2268 if (value.sgt(refValue))
2269 break;
2270 return reportError("greater than");
2271 case MatchCmpIPredicate::ge:
2272 if (value.sge(refValue))
2273 break;
2274 return reportError("greater than or equal to");
2275 }
2276 }
2278}
2279
// Only reads the `param` and `reference` handles; no payload effects are
// reported.
2280void transform::MatchParamCmpIOp::getEffects(
2282 onlyReadsHandle(getParamMutable(), effects);
2283 onlyReadsHandle(getReferenceMutable(), effects);
2284}
2285
2286//===----------------------------------------------------------------------===//
2287// ParamConstantOp
2288//===----------------------------------------------------------------------===//
2289
// Associates the result param with exactly one attribute: getValue().
2291transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2294 results.setParams(cast<OpResult>(getParam()), {getValue()});
2296}
2297
2298//===----------------------------------------------------------------------===//
2299// MergeHandlesOp
2300//===----------------------------------------------------------------------===//
2301
// Concatenates the payloads of all `handles` in operand order. The kind of
// payload (ops, params, or values) is chosen from the first handle's type.
// With getDeduplicate(), duplicates are removed through a SetVector, which
// keeps first-occurrence order.
// NOTE(review): `handles.front()` assumes at least one handle — presumably
// guaranteed by the op definition; confirm.
2303transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2306 ValueRange handles = getHandles();
2307 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2308 SmallVector<Operation *> operations;
2309 for (Value operand : handles)
2310 llvm::append_range(operations, state.getPayloadOps(operand));
2311 if (!getDeduplicate()) {
2312 results.set(llvm::cast<OpResult>(getResult()), operations);
2314 }
2315
2316 SetVector<Operation *> uniqued(llvm::from_range, operations);
2317 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2319 }
2320
2321 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2323 for (Value attribute : handles)
2324 llvm::append_range(attrs, state.getParams(attribute));
2325 if (!getDeduplicate()) {
2326 results.setParams(cast<OpResult>(getResult()), attrs);
2328 }
2329
2330 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2331 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2333 }
2334
// Remaining case: value handles.
2335 assert(
2336 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2337 "expected value handle type");
2338 SmallVector<Value> payloadValues;
2339 for (Value value : handles)
2340 llvm::append_range(payloadValues, state.getPayloadValues(value));
2341 if (!getDeduplicate()) {
2342 results.setValues(cast<OpResult>(getResult()), payloadValues);
2344 }
2345
2346 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2347 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2349}
2350
2351bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2352 // Handles may be the same if deduplicating is enabled.
2353 return getDeduplicate();
2354}
2355
// Only reads the input handles and produces the merged result handle.
2356void transform::MergeHandlesOp::getEffects(
2358 onlyReadsHandle(getHandlesMutable(), effects);
2359 producesHandle(getOperation()->getOpResults(), effects);
2360
2361 // There are no effects on the Payload IR as this is only a handle
2362 // manipulation.
2363}
2364
2365OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2366 if (getDeduplicate() || getHandles().size() != 1)
2367 return {};
2368
2369 // If deduplication is not required and there is only one operand, it can be
2370 // used directly instead of merging.
2371 return getHandles().front();
2372}
2373
2374//===----------------------------------------------------------------------===//
2375// NamedSequenceOp
2376//===----------------------------------------------------------------------===//
2377
// Applies the named sequence directly. An external (declaration-only) sequence
// is a definite failure. Otherwise the entry block arguments are mapped, and
// the body is applied with Propagate failure mode inside a region scope.
2379transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2382 if (isExternal())
2383 return emitDefiniteFailure() << "unresolved external named sequence";
2384
2385 // Map the entry block argument to the list of operations.
2386 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2387 // without attaching the interface / trait since that is tailored to a
2388 // dangling top-level op that does not get "called".
2389 auto scope = state.make_region_scope(getBody());
2390 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2391 state, this->getOperation(), getBody())))
2393
2394 return applySequenceBlock(getBody().front(),
2395 FailurePropagationMode::Propagate, state, results);
2396}
2397
// Side-effect reporting hook for NamedSequenceOp; presumably the same
// effects-vector signature as the other getEffects methods here — confirm.
2398void transform::NamedSequenceOp::getEffects(
2400
// Parses the function-like (non-variadic) syntax of a named sequence. The
// signature is stored in the function-type attribute, built with
// Builder::getFunctionType, and argument/result attributes go to the
// arg-attrs/res-attrs attributes.
2401ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2404 parser, result, /*allowVariadic=*/false,
2405 getFunctionTypeAttrName(result.name),
2406 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2408 std::string &) { return builder.getFunctionType(inputs, results); },
2409 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2410}
2411
// Prints the op in function-like (non-variadic) form, mirroring parse().
2412void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2414 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2415 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2416 getResAttrsAttrName());
2417}
2418
2419/// Verifies that a symbol function-like transform dialect operation has the
2420/// signature and the terminator that have conforming types, i.e., types
2421/// implementing the same transform dialect type interface. If `allowExternal`
2422/// is set, allow external symbols (declarations) and don't check the terminator
2423/// as it may not exist.
/// Note: despite the wording above, each terminator operand type must be
/// exactly equal to the corresponding result type (see the loop below), not
/// merely implement the same interface. The op also must not be nested inside
/// another transform op, must have a non-empty body block, and must be
/// terminated by `transform.yield`.
2425verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2426 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2429 << "cannot be defined inside another transform op";
2430 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2431 return diag;
2432 }
2433
2434 if (op.isExternal() || op.getFunctionBody().empty()) {
2435 if (allowExternal)
2437
2438 return emitSilenceableFailure(op) << "cannot be external";
2439 }
2440
2441 if (op.getFunctionBody().front().empty())
2442 return emitSilenceableFailure(op) << "expected a non-empty body block";
2443
2444 Operation *terminator = &op.getFunctionBody().front().back();
2445 if (!isa<transform::YieldOp>(terminator)) {
2447 << "expected '"
2448 << transform::YieldOp::getOperationName()
2449 << "' as terminator";
2450 diag.attachNote(terminator->getLoc()) << "terminator";
2451 return diag;
2452 }
2453
2454 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2455 return emitSilenceableFailure(terminator)
2456 << "expected terminator to have as many operands as the parent op "
2457 "has results";
2458 }
2459 for (auto [i, operandType, resultType] : llvm::zip_equal(
2460 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2461 terminator->getOperands().getType(), op.getResultTypes())) {
2462 if (operandType == resultType)
2463 continue;
2464 return emitSilenceableFailure(terminator)
2465 << "the type of the terminator operand #" << i
2466 << " must match the type of the corresponding parent op result ("
2467 << operandType << " vs " << resultType << ")";
2468 }
2469
2471}
2472
2473/// Verification of a NamedSequenceOp. This does not report the error
2474/// immediately, so it can be used to check for op's well-formedness before the
2475/// verifier runs, e.g., during trait verification.
/// Checks, in order:
///  - the enclosing symbol table carries kWithNamedSequenceAttrName;
///  - the op is not nested in another transform op;
///  - external sequences only have their consume annotations verified;
///  - otherwise the body is non-empty, all non-terminator ops implement
///    TransformOpInterface, and the terminator is `transform.yield` whose
///    operand types equal the declared result types;
///  - consume annotations are valid.
2477verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2478 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2479 if (!parent->getAttr(
2480 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2483 << "expects the parent symbol table to have the '"
2484 << transform::TransformDialect::kWithNamedSequenceAttrName
2485 << "' attribute";
2486 diag.attachNote(parent->getLoc()) << "symbol table operation";
2487 return diag;
2488 }
2489 }
2490
2491 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2494 << "cannot be defined inside another transform op";
2495 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2496 return diag;
2497 }
2498
2499 if (op.isExternal() || op.getBody().empty())
2500 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2501 emitWarnings);
2502
2503 if (op.getBody().front().empty())
2504 return emitSilenceableFailure(op) << "expected a non-empty body block";
2505
2506 // Check that all operations in the body implement TransformOpInterface
2507 for (Operation &child : op.getBody().front().without_terminator()) {
2508 if (!isa<transform::TransformOpInterface>(child)) {
2511 << "expected children ops to implement TransformOpInterface";
2512 diag.attachNote(child.getLoc()) << "op without interface";
2513 return diag;
2514 }
2515 }
2516
2517 Operation *terminator = &op.getBody().front().back();
2518 if (!isa<transform::YieldOp>(terminator)) {
2520 << "expected '"
2521 << transform::YieldOp::getOperationName()
2522 << "' as terminator";
2523 diag.attachNote(terminator->getLoc()) << "terminator";
2524 return diag;
2525 }
2526
2527 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2528 return emitSilenceableFailure(terminator)
2529 << "expected terminator to have as many operands as the parent op "
2530 "has results";
2531 }
2532 for (auto [i, operandType, resultType] :
2533 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2534 terminator->getOperands().getType(),
2535 op.getFunctionType().getResults())) {
2536 if (operandType == resultType)
2537 continue;
2538 return emitSilenceableFailure(terminator)
2539 << "the type of the terminator operand #" << i
2540 << " must match the type of the corresponding parent op result ("
2541 << operandType << " vs " << resultType << ")";
2542 }
2543
2544 auto funcOp = cast<FunctionOpInterface>(*op);
2546 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2547 if (!diag.succeeded())
2548 return diag;
2549
// NOTE(review): verifyYieldingSingleBlockOp repeats the nesting, body and
// terminator checks already performed above; kept as-is.
2550 return verifyYieldingSingleBlockOp(funcOp,
2551 /*allowExternal=*/true);
2552}
2553
2554LogicalResult transform::NamedSequenceOp::verify() {
2555 // Actual verification happens in a separate function for reusability.
2556 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2557}
2558
2559template <typename FnTy>
2560static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2561 Type bbArgType, TypeRange extraBindingTypes,
2562 FnTy bodyBuilder) {
2563 SmallVector<Type> types;
2564 types.reserve(1 + extraBindingTypes.size());
2565 types.push_back(bbArgType);
2566 llvm::append_range(types, extraBindingTypes);
2567
2568 OpBuilder::InsertionGuard guard(builder);
2569 Region *region = state.regions.back().get();
2570 Block *bodyBlock =
2571 builder.createBlock(region, region->begin(), types,
2572 SmallVector<Location>(types.size(), state.location));
2573
2574 // Populate body.
2575 builder.setInsertionPointToStart(bodyBlock);
2576 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2577 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2578 } else {
2579 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2580 bodyBlock->getArguments().drop_front());
2581 }
2582}
2583
// Builds a named sequence `symName` with signature (rootType) -> resultTypes.
// The extra `attrs` are appended, one region is added, and its entry block is
// populated by `bodyBuilder` (no extra bindings).
// NOTE(review): in the visible lines `argAttrs` is never used — confirm
// whether it should be stored in the arg-attrs attribute.
2584void transform::NamedSequenceOp::build(OpBuilder &builder,
2585 OperationState &state, StringRef symName,
2586 Type rootType, TypeRange resultTypes,
2587 SequenceBodyBuilderFn bodyBuilder,
2589 ArrayRef<DictionaryAttr> argAttrs) {
2591 builder.getStringAttr(symName));
2592 state.addAttribute(getFunctionTypeAttrName(state.name),
2593 TypeAttr::get(FunctionType::get(builder.getContext(),
2594 rootType, resultTypes)));
2595 state.attributes.append(attrs.begin(), attrs.end());
2596 state.addRegion();
2597
2598 buildSequenceBody(builder, state, rootType,
2599 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2600}
2601
2602//===----------------------------------------------------------------------===//
2603// NumAssociationsOp
2604//===----------------------------------------------------------------------===//
2605
// Produces a single i64 param holding the number of payload entities
// associated with `handle`: ops, values, or params depending on the handle's
// type interface.
2607transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2610 size_t numAssociations =
2612 .Case([&](TransformHandleTypeInterface opHandle) {
2613 return llvm::range_size(state.getPayloadOps(getHandle()));
2614 })
2615 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2616 return llvm::range_size(state.getPayloadValues(getHandle()));
2617 })
2618 .Case([&](TransformParamTypeInterface param) {
2619 return llvm::range_size(state.getParams(getHandle()));
2620 })
2621 .DefaultUnreachable("unknown kind of transform dialect type");
2622 results.setParams(cast<OpResult>(getNum()),
2623 rewriter.getI64IntegerAttr(numAssociations));
2625}
2626
2627LogicalResult transform::NumAssociationsOp::verify() {
2628 // Verify that the result type accepts an i64 attribute as payload.
2629 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2630 return resultType
2631 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2632 .checkAndReport();
2633}
2634
2635//===----------------------------------------------------------------------===//
2636// SelectOp
2637//===----------------------------------------------------------------------===//
2638
// Maps the result to the payload ops of `target` whose operation name equals
// getOpName(), preserving their order.
2640transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2644 auto payloadOps = state.getPayloadOps(getTarget());
2645 for (Operation *op : payloadOps) {
2646 if (op->getName().getStringRef() == getOpName())
2647 result.push_back(op);
2648 }
2649 results.set(cast<OpResult>(getResult()), result);
2651}
2652
2653//===----------------------------------------------------------------------===//
2654// SplitHandleOp
2655//===----------------------------------------------------------------------===//
2656
2657void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2658 Value target, int64_t numResultHandles) {
2659 result.addOperands(target);
2660 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2661}
2662
2664transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2667 int64_t numPayloads =
2669 .Case([&](TransformHandleTypeInterface x) {
2670 return llvm::range_size(state.getPayloadOps(getHandle()));
2671 })
2672 .Case([&](TransformValueHandleTypeInterface x) {
2673 return llvm::range_size(state.getPayloadValues(getHandle()));
2674 })
2675 .Case([&](TransformParamTypeInterface x) {
2676 return llvm::range_size(state.getParams(getHandle()));
2677 })
2678 .DefaultUnreachable("unknown transform dialect type interface");
2679
2680 auto produceNumOpsError = [&]() {
2681 return emitSilenceableError()
2682 << getHandle() << " expected to contain " << this->getNumResults()
2683 << " payloads but it contains " << numPayloads << " payloads";
2684 };
2685
2686 // Fail if there are more payload ops than results and no overflow result was
2687 // specified.
2688 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2689 return produceNumOpsError();
2690
2691 // Fail if there are more results than payload ops. Unless:
2692 // - "fail_on_payload_too_small" is set to "false", or
2693 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2694 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2695 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2696 return produceNumOpsError();
2697
2698 // Distribute payloads.
2699 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2700 if (getOverflowResult())
2701 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2702
2703 auto container = [&]() {
2704 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2705 return llvm::map_to_vector(
2706 state.getPayloadOps(getHandle()),
2707 [](Operation *op) -> MappedValue { return op; });
2708 }
2709 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2710 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2711 [](Value v) -> MappedValue { return v; });
2712 }
2713 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2714 "unsupported kind of transform dialect type");
2715 return llvm::map_to_vector(state.getParams(getHandle()),
2716 [](Attribute a) -> MappedValue { return a; });
2717 }();
2718
2719 for (auto &&en : llvm::enumerate(container)) {
2720 int64_t resultNum = en.index();
2721 if (resultNum >= getNumResults())
2722 resultNum = *getOverflowResult();
2723 resultHandles[resultNum].push_back(en.value());
2724 }
2725
2726 // Set transform op results.
2727 for (auto &&it : llvm::enumerate(resultHandles))
2728 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2729 it.value());
2730
2732}
2733
/// Declares the effects of `transform.split_handle`: the split handle is only
/// read, and every op result is produced as a new handle.
2734void transform::SplitHandleOp::getEffects(
2736 onlyReadsHandle(getHandleMutable(), effects);
2737 producesHandle(getOperation()->getOpResults(), effects);
2738 // There are no effects on the Payload IR as this is only a handle
2739 // manipulation.
2740}
2741
2742LogicalResult transform::SplitHandleOp::verify() {
2743 if (getOverflowResult().has_value() &&
2744 !(*getOverflowResult() < getNumResults()))
2745 return emitOpError("overflow_result is not a valid result index");
2746
2747 for (Type resultType : getResultTypes()) {
2748 if (implementSameTransformInterface(getHandle().getType(), resultType))
2749 continue;
2750
2751 return emitOpError("expects result types to implement the same transform "
2752 "interface as the operand type");
2753 }
2754
2755 return success();
2756}
2757
2758//===----------------------------------------------------------------------===//
2759// PayloadOp
2760//===----------------------------------------------------------------------===//
2761
/// Appends every entry of getNormalForms(), viewed as a
/// NormalFormAttrInterface, to `normalForms`.
2762void transform::PayloadOp::getCheckedNormalForms(
2764 llvm::append_range(normalForms,
2765 getNormalForms().getAsRange<NormalFormAttrInterface>());
2766}
2767
2768//===----------------------------------------------------------------------===//
2769// ReplicateOp
2770//===----------------------------------------------------------------------===//
2771
/// Replicates the contents of each handle in getHandles() once per payload op
/// associated with the `pattern` handle. The i-th replicated result holds the
/// i-th handle's contents repeated `numRepetitions` times back to back.
2773transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
// Number of copies to make, i.e. the size of the `pattern` payload.
// NOTE(review): llvm::range_size yields a size_t that is narrowed to unsigned.
2776 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2777 for (const auto &en : llvm::enumerate(getHandles())) {
2778 Value handle = en.value();
// Operation handles: repeat the list of payload ops.
2779 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2780 SmallVector<Operation *> current =
2781 llvm::to_vector(state.getPayloadOps(handle));
2783 payload.reserve(numRepetitions * current.size());
2784 for (unsigned i = 0; i < numRepetitions; ++i)
2785 llvm::append_range(payload, current);
2786 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2787 } else {
// Any non-operation handle is expected to be a parameter handle (asserted);
// repeat its list of parameter attributes.
2788 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2789 "expected param type");
2790 ArrayRef<Attribute> current = state.getParams(handle);
2792 params.reserve(numRepetitions * current.size());
2793 for (unsigned i = 0; i < numRepetitions; ++i)
2794 llvm::append_range(params, current);
2795 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2796 params);
2797 }
2798 }
2800}
2801
/// Declares the effects of `transform.replicate`. The pattern handle and all
/// replicated handles are only read, and the results are produced as new
/// handles.
2802void transform::ReplicateOp::getEffects(
2804 onlyReadsHandle(getPatternMutable(), effects);
2805 onlyReadsHandle(getHandlesMutable(), effects);
2806 producesHandle(getOperation()->getOpResults(), effects);
2807}
2808
2809//===----------------------------------------------------------------------===//
2810// SequenceOp
2811//===----------------------------------------------------------------------===//
2812
/// Applies the body of `transform.sequence`. The body block arguments are
/// mapped inside a fresh region scope, then applySequenceBlock runs the
/// contained ops under this op's failure propagation mode.
2814transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2817 // Map the entry block argument to the list of operations.
// The scope object lives until the end of the function, so the mappings it
// guards stay valid while the body is applied.
2818 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2819 if (failed(mapBlockArguments(state)))
2821
2822 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2823 results);
2824}
2825
2826static ParseResult parseSequenceOpOperands(
2827 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2828 Type &rootType,
2830 SmallVectorImpl<Type> &extraBindingTypes) {
2832 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2833 if (!hasRoot.has_value()) {
2834 root = std::nullopt;
2835 return success();
2836 }
2837 if (failed(hasRoot.value()))
2838 return failure();
2839 root = rootOperand;
2840
2841 if (succeeded(parser.parseOptionalComma())) {
2842 if (failed(parser.parseOperandList(extraBindings)))
2843 return failure();
2844 }
2845 if (failed(parser.parseColon()))
2846 return failure();
2847
2848 // The paren is truly optional.
2849 (void)parser.parseOptionalLParen();
2850
2851 if (failed(parser.parseType(rootType))) {
2852 return failure();
2853 }
2854
2855 if (!extraBindings.empty()) {
2856 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2857 return failure();
2858 }
2859
2860 if (extraBindingTypes.size() != extraBindings.size()) {
2861 return parser.emitError(parser.getNameLoc(),
2862 "expected types to be provided for all operands");
2863 }
2864
2865 // The paren is truly optional.
2866 (void)parser.parseOptionalRParen();
2867 return success();
2868}
2869
2871 Value root, Type rootType,
2872 ValueRange extraBindings,
2873 TypeRange extraBindingTypes) {
// Printer counterpart of parseSequenceOpOperands. Without a root nothing is
// printed. Otherwise it prints `%root (, %extra)* : type`, wrapping the type
// list in parentheses only when extra bindings are present.
2874 if (!root)
2875 return;
2876
2877 printer << root;
2878 bool hasExtras = !extraBindings.empty();
2879 if (hasExtras) {
2880 printer << ", ";
2881 printer.printOperands(extraBindings);
2882 }
2883
2884 printer << " : ";
2885 if (hasExtras)
2886 printer << "(";
2887
2888 printer << rootType;
// llvm::interleaved separates the extra types with ", " by default.
2889 if (hasExtras)
2890 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2891}
2892
2893/// Returns `true` if the given op operand may be consuming the handle value in
2894/// the Transform IR. That is, if it may have a Free effect on it.
/// Owners that do not implement TransformOpInterface are conservatively
/// treated as potential consumers. Otherwise the answer comes from
/// isHandleConsumed on the used value.
2896 // Conservatively assume the effect being present in absence of the interface.
2897 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2898 if (!iface)
2899 return true;
2900
2901 return isHandleConsumed(use.get(), iface);
2902}
2903
/// Fails if `value` has more than one use that may consume it. Some uses are
/// skipped by the `continue` at the top of the loop; presumably these are the
/// uses for which isValueUsePotentialConsumer is false (TODO confirm, the
/// condition line is not visible here). On the second potential consumer, the
/// diagnostic started by `reportError` is extended with "has more than one
/// potential consumer". It gets notes pointing at both uses and is returned.
2904static LogicalResult
2906 function_ref<InFlightDiagnostic()> reportError) {
2907 OpOperand *potentialConsumer = nullptr;
2908 for (OpOperand &use : value.getUses()) {
2910 continue;
2911
// Remember the first potential consumer; only a second one is an error.
2912 if (!potentialConsumer) {
2913 potentialConsumer = &use;
2914 continue;
2915 }
2916
2917 InFlightDiagnostic diag = reportError()
2918 << " has more than one potential consumer";
2919 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2920 << "used here as operand #" << potentialConsumer->getOperandNumber();
2921 diag.attachNote(use.getOwner()->getLoc())
2922 << "used here as operand #" << use.getOperandNumber();
2923 return diag;
2924 }
2925
2926 return success();
2927}
2928
/// Verifies `transform.sequence`. The checks are:
///  - no extra bindings without a root (top-level use);
///  - no block argument or child result has more than one potential consumer;
///  - every child except the last implements TransformOpInterface;
///  - the body has a terminator whose operand types match the op's result
///    types.
2929LogicalResult transform::SequenceOp::verify() {
// NOTE(review): the message says "more than 1" but the condition checks for
// at least one argument.
2930 assert(getBodyBlock()->getNumArguments() >= 1 &&
2931 "the number of arguments must have been verified to be more than 1 by "
2932 "PossibleTopLevelTransformOpTrait");
2933
2934 if (!getRoot() && !getExtraBindings().empty()) {
2935 return emitOpError()
2936 << "does not expect extra operands when used as top-level";
2937 }
2938
2939 // Check if a block argument has more than one consuming use.
2940 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2941 if (failed(checkDoubleConsume(arg, [this, arg]() {
2942 return (emitOpError() << "block argument #" << arg.getArgNumber());
2943 }))) {
2944 return failure();
2945 }
2946 }
2947
2948 // Check properties of the nested operations they cannot check themselves.
2949 for (Operation &child : *getBodyBlock()) {
// The last op in the block (the terminator) is exempt from the interface
// requirement.
2950 if (!isa<TransformOpInterface>(child) &&
2951 &child != &getBodyBlock()->back()) {
2953 emitOpError()
2954 << "expected children ops to implement TransformOpInterface";
2955 diag.attachNote(child.getLoc()) << "op without interface";
2956 return diag;
2957 }
2958
2959 for (OpResult result : child.getResults()) {
2960 auto report = [&]() {
2961 return (child.emitError() << "result #" << result.getResultNumber());
2962 };
2963 if (failed(checkDoubleConsume(result, report)))
2964 return failure();
2965 }
2966 }
2967
2968 if (!getBodyBlock()->mightHaveTerminator())
2969 return emitOpError() << "expects to have a terminator in the body";
2970
// Values yielded by the terminator become the op results, so the types must
// line up exactly.
2971 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2972 getOperation()->getResultTypes()) {
2974 << "expects the types of the terminator operands "
2975 "to match the types of the result";
2976 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2977 return diag;
2978 }
2979 return success();
2980}
2981
/// Declares the memory effects of `transform.sequence`.
/// NOTE(review): the parameter list and body of this definition are not
/// visible in this excerpt; confirm the effects against the full source.
2982void transform::SequenceOp::getEffects(
2985}
2986
/// Returns the operands forwarded into the body region when the op is entered.
/// This is all op operands if there are any. Otherwise it is an empty range
/// anchored at operand_end().
2988transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2989 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2990 if (getOperation()->getNumOperands() > 0)
2991 return getOperation()->getOperands();
2992 return OperandRange(getOperation()->operand_end(),
2993 getOperation()->operand_end());
2994}
2995
/// Control flow of `transform.sequence`. Entering from the parent branches
/// into the body region. Leaving the body (via its terminator) branches back
/// to the parent.
2996void transform::SequenceOp::getSuccessorRegions(
2998 if (point.isParent()) {
2999 Region *bodyRegion = &getBody();
3000 regions.emplace_back(bodyRegion);
3001 return;
3002 }
3003
3004 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
3005 &getBody() &&
3006 "unexpected region index");
3007 regions.push_back(RegionSuccessor::parent());
3008}
3009
/// Returns the values that receive control-flow operands for `successor`:
/// the op results for the parent, the body block arguments for the body.
/// An op without operands reports no inputs for either successor.
/// NOTE(review): for a top-level sequence (no operands) that still has
/// results, the parent case also returns an empty range. Confirm this matches
/// the terminator operands expected by RegionBranchOpInterface verification.
3011transform::SequenceOp::getSuccessorInputs(RegionSuccessor successor) {
3012 if (getNumOperands() == 0)
3013 return ValueRange();
3014 if (successor.isParent())
3015 return getResults();
3016 return getBody().getArguments();
3017}
3018
/// The single body region is invoked exactly once, independent of the
/// operand values.
3019void transform::SequenceOp::getRegionInvocationBounds(
3021 (void)operands;
3022 bounds.emplace_back(1, 1);
3023}
3024
3025void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3026 TypeRange resultTypes,
3027 FailurePropagationMode failurePropagationMode,
3028 Value root,
3029 SequenceBodyBuilderFn bodyBuilder) {
3030 build(builder, state, resultTypes, failurePropagationMode, root,
3031 /*extra_bindings=*/ValueRange());
3032 Type bbArgType = root.getType();
3033 buildSequenceBody(builder, state, bbArgType,
3034 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3035}
3036
3037void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3038 TypeRange resultTypes,
3039 FailurePropagationMode failurePropagationMode,
3040 Value root, ValueRange extraBindings,
3041 SequenceBodyBuilderArgsFn bodyBuilder) {
3042 build(builder, state, resultTypes, failurePropagationMode, root,
3043 extraBindings);
3044 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3045 bodyBuilder);
3046}
3047
3048void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3049 TypeRange resultTypes,
3050 FailurePropagationMode failurePropagationMode,
3051 Type bbArgType,
3052 SequenceBodyBuilderFn bodyBuilder) {
3053 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3054 /*extra_bindings=*/ValueRange());
3055 buildSequenceBody(builder, state, bbArgType,
3056 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3057}
3058
3059void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3060 TypeRange resultTypes,
3061 FailurePropagationMode failurePropagationMode,
3062 Type bbArgType, TypeRange extraBindingTypes,
3063 SequenceBodyBuilderArgsFn bodyBuilder) {
3064 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3065 /*extra_bindings=*/ValueRange());
3066 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3067}
3068
3069//===----------------------------------------------------------------------===//
3070// PrintOp
3071//===----------------------------------------------------------------------===//
3072
3073void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3074 StringRef name) {
3075 if (!name.empty())
3076 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3077}
3078
3079void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3080 Value target, StringRef name) {
3081 result.addOperands({target});
3082 build(builder, result, name);
3083}
3084
/// Prints IR to llvm::outs() under a "[[[ IR printer: <name> ]]]" banner. The
/// whole top-level op is printed when there is no target; otherwise each
/// payload op of the target handle is printed. The optional
/// assume_verified / use_local_scope / skip_regions attributes control the
/// printing flags.
3086transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3089 llvm::outs() << "[[[ IR printer: ";
3090 if (getName().has_value())
3091 llvm::outs() << *getName() << " ";
3092
// Unset flag attributes default to false.
3093 OpPrintingFlags printFlags;
3094 if (getAssumeVerified().value_or(false))
3095 printFlags.assumeVerified();
3096 if (getUseLocalScope().value_or(false))
3097 printFlags.useLocalScope();
3098 if (getSkipRegions().value_or(false))
3099 printFlags.skipRegions();
3100
// Without a target handle, print the whole top-level payload op.
3101 if (!getTarget()) {
3102 llvm::outs() << "top-level ]]]\n";
3103 state.getTopLevel()->print(llvm::outs(), printFlags);
3104 llvm::outs() << "\n";
3105 llvm::outs().flush();
3107 }
3108
3109 llvm::outs() << "]]]\n";
3110 for (Operation *target : state.getPayloadOps(getTarget())) {
3111 target->print(llvm::outs(), printFlags);
3112 llvm::outs() << "\n";
3113 }
3114
3115 llvm::outs().flush();
3117}
3118
/// Declares the effects of `transform.print`. The optional target handle is
/// only read, the payload is only read, and a generic write effect stands in
/// for the output printing.
3119void transform::PrintOp::getEffects(
3121 // We don't really care about mutability here, but `getTarget` now
3122 // unconditionally casts to a specific type before verification could run
3123 // here.
3124 if (!getTargetMutable().empty())
3125 onlyReadsHandle(getTargetMutable()[0], effects);
3126 onlyReadsPayload(effects);
3127
3128 // There is no resource for the stdout file descriptor (apply() prints to
3129 // llvm::outs()), so just declare print writes into the default resource.
3130 effects.emplace_back(MemoryEffects::Write::get());
3131}
3132
3133//===----------------------------------------------------------------------===//
3134// VerifyOp
3135//===----------------------------------------------------------------------===//
3136
/// Runs the verifier on a single payload op. On failure, returns a diagnostic
/// "failed to verify payload op" with a note at the payload op's location.
/// NOTE(review): the verification call itself is not visible in this excerpt.
3138transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3144 << "failed to verify payload op";
3145 diag.attachNote(target->getLoc()) << "payload op";
3146 return diag;
3147 }
3149}
3150
/// Declares the effects of `transform.verify`: the target handle is only read.
3151void transform::VerifyOp::getEffects(
3153 transform::onlyReadsHandle(getTargetMutable(), effects);
3154}
3155
3156//===----------------------------------------------------------------------===//
3157// YieldOp
3158//===----------------------------------------------------------------------===//
3159
/// Declares the effects of `transform.yield`: all yielded handles are only
/// read.
3160void transform::YieldOp::getEffects(
3162 onlyReadsHandle(getOperandsMutable(), effects);
3163}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
GreedyRewriteConfig & setListener(RewriterBase::Listener *listener)
GreedyRewriteConfig & enableCSEBetweenIterations(bool enable=true)
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setMaxNumRewrites(int64_t limit)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1143
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
unsigned getNumOperands()
Definition Operation.h:372
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getOpResults()
Definition Operation.h:446
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:418
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.