MLIR 23.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/Support/Debug.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/ErrorHandling.h"
43#include "llvm/Support/InterleavedRange.h"
44#include <optional>
45
46#define DEBUG_TYPE "transform-dialect"
47#define DEBUG_TYPE_MATCHER "transform-matcher"
48
49using namespace mlir;
50
// Forward declarations of the custom parse/print helpers defined later in this
// file. They are declared before the generated op definitions are included
// (GET_OP_CLASSES below), presumably because the generated assembly-format code
// calls them.
51static ParseResult parseApplyRegisteredPassOptions(
52 OpAsmParser &parser, DictionaryAttr &options,
55 Operation *op,
56 DictionaryAttr options,
57 ValueRange dynamicOptions);
58static ParseResult parseSequenceOpOperands(
59 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
60 Type &rootType,
62 SmallVectorImpl<Type> &extraBindingTypes);
63static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
64 Value root, Type rootType,
65 ValueRange extraBindings,
66 TypeRange extraBindingTypes);
67static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
69static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
71 ArrayAttr &actions);
72
73/// Helper function to check if the given transform op is contained in (or
74/// equal to) the given payload target op. In that case, an error is returned.
75/// Transforming transform IR that is currently executing is generally unsafe.
77ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
78 Operation *payload) {
79 Operation *transformAncestor = transform.getOperation();
 // Walk from the transform op up through all of its ancestors. Meeting
 // `payload` on the way means the payload is the transform op itself or
 // encloses it.
80 while (transformAncestor) {
81 if (transformAncestor == payload) {
83 transform.emitDefiniteFailure()
84 << "cannot apply transform to itself (or one of its ancestors)";
85 diag.attachNote(payload->getLoc()) << "target payload op";
86 return diag;
87 }
88 transformAncestor = transformAncestor->getParentOp();
89 }
91}
92
93#define GET_OP_CLASSES
94#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
95
96//===----------------------------------------------------------------------===//
97// AlternativesOp
98//===----------------------------------------------------------------------===//
99
100OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
101 RegionSuccessor successor) {
102 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
103 return getOperation()->getOperands();
104 return OperandRange(getOperation()->operand_end(),
105 getOperation()->operand_end());
106}
107
/// Collects the regions control may flow to from `point`. When entering from
/// the parent op, every alternative region is a candidate. From within an
/// alternative, control may continue to any alternative with a higher region
/// number, or return to the parent op.
108void transform::AlternativesOp::getSuccessorRegions(
110 for (Region &alternative : llvm::drop_begin(
111 getAlternatives(), point.isParent()
112 ? 0
114 ->getParentRegion()
115 ->getRegionNumber() +
116 1)) {
117 regions.emplace_back(&alternative);
118 }
 // Only a region (not the parent op itself) can yield back to the parent.
119 if (!point.isParent())
120 regions.push_back(RegionSuccessor::parent());
121}
122
124transform::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) {
 // Values that receive control-flow inputs at `successor`. These are the op
 // results when returning to the parent, otherwise the arguments of the
 // successor region's entry block.
125 if (successor.isParent())
126 return getOperation()->getResults();
127 return successor.getSuccessor()->getArguments();
128}
129
/// Reports how often each alternative region may run: exactly once for the
/// first region and at most once for every other region. The bounds do not
/// depend on the operand values.
130void transform::AlternativesOp::getRegionInvocationBounds(
132 (void)operands;
133 // The region corresponding to the first alternative is always executed, the
134 // remaining may or may not be executed.
135 bounds.reserve(getNumRegions());
136 bounds.emplace_back(1, 1);
137 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
138}
139
 // Map every result of the block's parent op to an empty payload list.
142 for (const auto &res : block->getParentOp()->getOpResults())
143 results.set(res, {});
144}
145
147transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
 // Collect the scope ops. These are the payload of the optional scope handle,
 // or the top-level payload op when no scope operand is given.
150 SmallVector<Operation *> originals;
151 if (Value scopeHandle = getScope())
152 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
153 else
154 originals.push_back(state.getTopLevel());
155
 // Each scope op is cloned and later replaced, so it must not contain this
 // transform. It must also be isolated from above.
156 for (Operation *original : originals) {
157 if (original->isAncestor(getOperation())) {
159 << "scope must not contain the transforms being applied";
160 diag.attachNote(original->getLoc()) << "scope";
161 return diag;
162 }
163 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
165 << "only isolated-from-above ops can be alternative scopes";
166 diag.attachNote(original->getLoc()) << "scope";
167 return diag;
168 }
169 }
170
171 for (Region &reg : getAlternatives()) {
172 // Clone the scope operations and make the transforms in this alternative
173 // region apply to them by virtue of mapping the block argument (the only
174 // visible handle) to the cloned scope operations. This effectively prevents
175 // the transformation from accessing any IR outside the scope.
176 auto scope = state.make_region_scope(reg);
177 auto clones = llvm::map_to_vector(
178 originals, [](Operation *op) { return op->clone(); });
179 llvm::scope_exit deleteClones([&] {
180 for (Operation *clone : clones)
181 clone->erase();
182 });
183 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
185
 // Apply this alternative's transforms in order. A silenceable failure
 // abandons the alternative; its clones are then erased by `deleteClones`.
186 bool failed = false;
187 for (Operation &transform : reg.front().without_terminator()) {
189 state.applyTransform(cast<TransformOpInterface>(transform));
190 if (result.isSilenceableFailure()) {
191 LDBG() << "alternative failed: " << result.getMessage();
192 failed = true;
193 break;
194 }
195
196 if (::mlir::failed(result.silence()))
198 }
199
200 // If all operations in the given alternative succeeded, no need to consider
201 // the rest. Replace the original scoping operation with the clone on which
202 // the transformations were performed.
203 if (!failed) {
204 // We will be using the clones, so cancel their scheduled deletion.
205 deleteClones.release();
 // NOTE: this local `rewriter` shadows the `rewriter` parameter. It
 // reports replacements to a TrackingListener bound to `state`.
206 TrackingListener listener(state, *this);
207 IRRewriter rewriter(getContext(), &listener);
208 for (const auto &kvp : llvm::zip(originals, clones)) {
209 Operation *original = std::get<0>(kvp);
210 Operation *clone = std::get<1>(kvp);
 // Insert the clone right before its original, then replace the
 // original's uses with the clone's results.
211 original->getBlock()->getOperations().insert(original->getIterator(),
212 clone);
213 rewriter.replaceOp(original, clone->getResults());
214 }
215 detail::forwardTerminatorOperands(&reg.front(), state, results);
217 }
218 }
 // Reached only when no alternative ran without a silenceable failure
 // (presumably the success path above returns early).
219 return emitSilenceableError() << "all alternatives failed";
220}
221
/// Declares memory effects. All operand handles are consumed. The result
/// handles and the entry-block arguments of every non-empty alternative region
/// are produced. The payload IR is marked as modified.
222void transform::AlternativesOp::getEffects(
224 consumesHandle(getOperation()->getOpOperands(), effects);
225 producesHandle(getOperation()->getOpResults(), effects);
226 for (Region *region : getRegions()) {
227 if (!region->empty())
228 producesHandle(region->front().getArguments(), effects);
229 }
230 modifiesPayload(effects);
231}
232
/// Verifies that in every alternative region, the terminator's operand types
/// match this op's result types exactly.
233LogicalResult transform::AlternativesOp::verify() {
234 for (Region &alternative : getAlternatives()) {
 // NOTE(review): front() assumes every alternative has a block, whereas
 // getEffects() guards against empty regions. Confirm that the op's
 // region constraints rule out empty alternatives.
235 Block &block = alternative.front();
236 Operation *terminator = block.getTerminator();
237 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
239 << "expects terminator operands to have the "
240 "same type as results of the operation";
241 diag.attachNote(terminator->getLoc()) << "terminator";
242 return diag;
243 }
244 }
245
246 return success();
247}
248
249//===----------------------------------------------------------------------===//
250// AnnotateOp
251//===----------------------------------------------------------------------===//
252
254transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
258 llvm::to_vector(state.getPayloadOps(getTarget()));
259
 // Without a param, every target gets a unit attribute named getName().
260 Attribute attr = UnitAttr::get(getContext());
261 if (auto paramH = getParam()) {
262 ArrayRef<Attribute> params = state.getParams(paramH);
 // A single param value is broadcast to every target. Otherwise the param
 // values must pair one-to-one with the targets.
263 if (params.size() != 1) {
264 if (targets.size() != params.size()) {
265 return emitSilenceableError()
266 << "parameter and target have different payload lengths ("
267 << params.size() << " vs " << targets.size() << ")";
268 }
 // NOTE(review): the structured binding `attr` shadows the outer `attr`.
269 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
270 target->setAttr(getName(), attr);
272 }
273 attr = params[0];
274 }
275 for (auto *target : targets)
276 target->setAttr(getName(), attr);
278}
279
/// Both the target handle and the optional param handle are only read.
/// Setting attributes on the targets counts as modifying the payload.
280void transform::AnnotateOp::getEffects(
282 onlyReadsHandle(getTargetMutable(), effects);
283 onlyReadsHandle(getParamMutable(), effects);
284 modifiesPayload(effects);
285}
286
287//===----------------------------------------------------------------------===//
288// ApplyCommonSubexpressionEliminationOp
289//===----------------------------------------------------------------------===//
290
292transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
294 ApplyToEachResultList &results, transform::TransformState &state) {
295 // Make sure that this transform is not applied to itself. Modifying the
296 // transform IR while it is being interpreted is generally dangerous.
297 DiagnosedSilenceableFailure payloadCheck =
299 if (!payloadCheck.succeeded())
300 return payloadCheck;
301
 // Dominance analysis for the common-subexpression elimination step.
302 DominanceInfo domInfo;
305}
306
/// Declares effects: the target handle is only read, not consumed.
307void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
309 transform::onlyReadsHandle(getTargetMutable(), effects);
311}
312
313//===----------------------------------------------------------------------===//
314// ApplyDeadCodeEliminationOp
315//===----------------------------------------------------------------------===//
316
/// Erases trivially dead ops nested under `target` (never `target` itself).
/// It keeps going until ops that only became dead because their users were
/// erased are removed as well.
317DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
319 ApplyToEachResultList &results, transform::TransformState &state) {
320 // Make sure that this transform is not applied to itself. Modifying the
321 // transform IR while it is being interpreted is generally dangerous.
322 DiagnosedSilenceableFailure payloadCheck =
324 if (!payloadCheck.succeeded())
325 return payloadCheck;
326
327 // Maintain a worklist of potentially dead ops.
328 SetVector<Operation *> worklist;
329
330 // Helper function that adds all defining ops of used values (operands and
331 // operands of nested ops).
332 auto addDefiningOpsToWorklist = [&](Operation *op) {
333 op->walk([&](Operation *op) {
334 for (Value v : op->getOperands())
335 if (Operation *defOp = v.getDefiningOp())
 // Only ops strictly nested under `target` are candidates.
336 if (target->isProperAncestor(defOp))
337 worklist.insert(defOp);
338 });
339 };
340
341 // Helper function that erases an op.
342 auto eraseOp = [&](Operation *op) {
343 // Remove op and nested ops from the worklist.
344 op->walk([&](Operation *op) {
 // NOTE(review): llvm::find scans the worklist linearly for every nested
 // op. SetVector::remove would use the set for the lookup instead.
345 const auto *it = llvm::find(worklist, op);
346 if (it != worklist.end())
347 worklist.erase(it);
348 });
349 rewriter.eraseOp(op);
350 };
351
352 // Initial walk over the IR.
353 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
354 if (op != target && isOpTriviallyDead(op)) {
355 addDefiningOpsToWorklist(op);
356 eraseOp(op);
357 }
358 });
359
360 // Erase all ops that have become dead.
361 while (!worklist.empty()) {
362 Operation *op = worklist.pop_back_val();
363 if (!isOpTriviallyDead(op))
364 continue;
365 addDefiningOpsToWorklist(op);
366 eraseOp(op);
367 }
368
370}
371
/// Declares effects: the target handle is only read, not consumed.
372void transform::ApplyDeadCodeEliminationOp::getEffects(
374 transform::onlyReadsHandle(getTargetMutable(), effects);
376}
377
378//===----------------------------------------------------------------------===//
379// ApplyPatternsOp
380//===----------------------------------------------------------------------===//
381
/// Greedily applies the patterns described by the ops in this op's region to
/// `target`. When CSE is requested, pattern application and CSE alternate until
/// CSE reports no change, up to a fixed iteration cap.
382DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
384 ApplyToEachResultList &results, transform::TransformState &state) {
385 // Make sure that this transform is not applied to itself. Modifying the
386 // transform IR while it is being interpreted is generally dangerous. Even
387 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
388 // performs many additional simplifications such as dead code elimination.
389 DiagnosedSilenceableFailure payloadCheck =
391 if (!payloadCheck.succeeded())
392 return payloadCheck;
393
394 // Gather all specified patterns.
395 MLIRContext *ctx = target->getContext();
397 if (!getRegion().empty()) {
398 for (Operation &op : getRegion().front()) {
399 cast<transform::PatternDescriptorOpInterface>(&op)
400 .populatePatternsWithState(patterns, state);
401 }
402 }
403
404 // Configure the GreedyPatternRewriteDriver.
 // Rewrites are reported to the transform rewriter's listener (presumably so
 // that tracked payload handles stay up to date).
406 config.setListener(
407 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
408 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
409
 // The sentinel value static_cast<uint64_t>(-1) is not passed through as-is.
 // It is mapped to a separate driver setting (presumably "no limit").
410 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
412 : getMaxIterations());
413 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
415 : getMaxNumRewrites());
416
417 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
418 // was requested, apply the greedy pattern rewrite only once. (The greedy
419 // pattern rewrite driver already iterates to a fixpoint internally.)
420 bool cseChanged = false;
421 // One or two iterations should be sufficient. Stop iterating after a certain
422 // threshold to make debugging easier.
423 static const int64_t kNumMaxIterations = 50;
424 int64_t iteration = 0;
425 do {
426 LogicalResult result = failure();
427 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
428 // Op is isolated from above. Apply patterns and also perform region
429 // simplification.
430 result = applyPatternsGreedily(target, frozenPatterns, config);
431 } else {
432 // Manually gather list of ops because the other
433 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
434 // from above. This way, patterns can be applied to ops that are not
435 // isolated from above. Regions are not being simplified. Furthermore,
436 // only a single greedy rewrite iteration is performed.
438 target->walk([&](Operation *nestedOp) {
439 if (target != nestedOp)
440 ops.push_back(nestedOp);
441 });
442 result = applyOpPatternsGreedily(ops, frozenPatterns, config);
443 }
444
445 // A failure typically indicates that the pattern application did not
446 // converge.
447 if (failed(result)) {
449 << "greedy pattern application failed";
450 }
451
 // Optionally run CSE (presumably over `target`). Whether it changed the IR
 // is recorded in `cseChanged`, which drives the loop condition below.
452 if (getApplyCse()) {
453 DominanceInfo domInfo;
455 &cseChanged);
456 }
457 } while (cseChanged && ++iteration < kNumMaxIterations);
458
 // The counter only reaches the cap if CSE still reported changes when the
 // loop condition was last evaluated.
459 if (iteration == kNumMaxIterations)
460 return emitDefiniteFailure() << "fixpoint iteration did not converge";
461
463}
464
/// Verifies that every op in the (optional) pattern region implements
/// PatternDescriptorOpInterface. applyToOne() relies on this when it cast<>s
/// each child op.
465LogicalResult transform::ApplyPatternsOp::verify() {
466 if (!getRegion().empty()) {
467 for (Operation &op : getRegion().front()) {
468 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
470 << "expected children ops to implement "
471 "PatternDescriptorOpInterface";
472 diag.attachNote(op.getLoc()) << "op without interface";
473 return diag;
474 }
475 }
476 }
477 return success();
478}
479
/// Declares effects: the target handle is only read, not consumed.
480void transform::ApplyPatternsOp::getEffects(
482 transform::onlyReadsHandle(getTargetMutable(), effects);
484}
485
/// Builds the op with `target` as its operand and one region containing a
/// single block. The optional `bodyBuilder` is invoked with the builder
/// positioned in that block.
486void transform::ApplyPatternsOp::build(
488 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
489 result.addOperands(target);
490
 // Restore the caller's insertion point once the region body is built.
491 OpBuilder::InsertionGuard g(builder);
492 Region *region = result.addRegion();
493 builder.createBlock(region);
494 if (bodyBuilder)
495 bodyBuilder(builder, result.location);
496}
497
498//===----------------------------------------------------------------------===//
499// ApplyCanonicalizationPatternsOp
500//===----------------------------------------------------------------------===//
501
/// Populates `patterns` with the dialect-level canonicalization patterns of
/// every loaded dialect. It then adds op-level canonicalization patterns
/// (presumably for every registered operation).
502void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
504 MLIRContext *ctx = patterns.getContext();
505 for (Dialect *dialect : ctx->getLoadedDialects())
506 dialect->getCanonicalizationPatterns(patterns);
508 op.getCanonicalizationPatterns(patterns, ctx);
509}
510
511//===----------------------------------------------------------------------===//
512// ApplyConversionPatternsOp
513//===----------------------------------------------------------------------===//
514
/// Runs a dialect conversion on every target payload op. The conversion uses
/// the legality lists on this op and the patterns and type converters supplied
/// by the pattern-descriptor ops in the patterns region.
515DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
518 MLIRContext *ctx = getContext();
519
520 // Instantiate the default type converter if a type converter builder is
521 // specified.
522 std::unique_ptr<TypeConverter> defaultTypeConverter;
523 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
524 getDefaultTypeConverter();
525 if (typeConverterBuilder)
526 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
527
528 // Configure conversion target.
 // Legality is seeded from the optional lists of legal/illegal op names and
 // dialect names. Each descriptor may add further rules below.
529 ConversionTarget conversionTarget(*getContext());
530 if (getLegalOps())
531 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
532 conversionTarget.addLegalOp(
533 OperationName(cast<StringAttr>(attr).getValue(), ctx));
534 if (getIllegalOps())
535 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
536 conversionTarget.addIllegalOp(
537 OperationName(cast<StringAttr>(attr).getValue(), ctx));
538 if (getLegalDialects())
539 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
540 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
541 if (getIllegalDialects())
542 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
543 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
544
545 // Gather all specified patterns.
547 // Need to keep the converters alive until after pattern application because
548 // the patterns take a reference to an object that would otherwise get out of
549 // scope.
551 if (!getPatterns().empty()) {
552 for (Operation &op : getPatterns().front()) {
553 auto descriptor =
554 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
555
556 // Check if this pattern set specifies a type converter.
557 std::unique_ptr<TypeConverter> typeConverter =
558 descriptor.getTypeConverter();
559 TypeConverter *converter = nullptr;
560 if (typeConverter) {
561 keepAliveConverters.emplace_back(std::move(typeConverter));
562 converter = keepAliveConverters.back().get();
563 } else {
564 // No type converter specified: Use the default type converter.
565 if (!defaultTypeConverter) {
567 << "pattern descriptor does not specify type "
568 "converter and apply_conversion_patterns op has "
569 "no default type converter";
570 diag.attachNote(op.getLoc()) << "pattern descriptor op";
571 return diag;
572 }
573 converter = defaultTypeConverter.get();
574 }
575
576 // Add descriptor-specific updates to the conversion target, which may
577 // depend on the final type converter. In structural converters, the
578 // legality of types dictates the dynamic legality of an operation.
579 descriptor.populateConversionTargetRules(*converter, conversionTarget);
580
581 descriptor.populatePatterns(*converter, patterns);
582 }
583 }
584
585 // Attach a tracking listener if handles should be preserved. We configure the
586 // listener to allow op replacements with different names, as conversion
587 // patterns typically replace ops with replacement ops that have a different
588 // name.
589 TrackingListenerConfig trackingConfig;
590 trackingConfig.requireMatchingReplacementOpName = false;
591 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
592 ConversionConfig conversionConfig;
593 if (getPreserveHandles())
594 conversionConfig.listener = &trackingListener;
595
596 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
597 for (Operation *target : state.getPayloadOps(getTarget())) {
598 // Make sure that this transform is not applied to itself. Modifying the
599 // transform IR while it is being interpreted is generally dangerous.
600 DiagnosedSilenceableFailure payloadCheck =
602 if (!payloadCheck.succeeded())
603 return payloadCheck;
604
 // Partial or full conversion, selected by getPartialConversion().
605 LogicalResult status = failure();
606 if (getPartialConversion()) {
607 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
608 conversionConfig);
609 } else {
610 status = applyFullConversion(target, conversionTarget, frozenPatterns,
611 conversionConfig);
612 }
613
614 // Check dialect conversion state.
616 if (failed(status)) {
617 diag = emitSilenceableError() << "dialect conversion failed";
618 diag.attachNote(target->getLoc()) << "target op";
619 }
620
621 // Check tracking listener error state.
 // The listener is only attached when handles are preserved. It is still
 // queried (and reset) for every target.
622 DiagnosedSilenceableFailure trackingFailure =
623 trackingListener.checkAndResetError();
624 if (!trackingFailure.succeeded()) {
625 if (diag.succeeded()) {
626 // Tracking failure is the only failure.
627 return trackingFailure;
628 }
 // Both failed. Report the conversion failure, fold the tracking error
 // into a note, and silence the tracking failure so it is not emitted
 // separately.
629 diag.attachNote() << "tracking listener also failed: "
630 << trackingFailure.getMessage();
631 (void)trackingFailure.silence();
632 }
633
634 if (!diag.succeeded())
635 return diag;
636 }
637
639}
640
/// Verifies the op's region structure:
///  - the op has one or two regions;
///  - every op in the patterns region implements
///    ConversionPatternDescriptorOpInterface;
///  - if present, the second (default type converter) region holds exactly one
///    op implementing TypeConverterBuilderOpInterface;
///  - that converter is accepted by every pattern descriptor.
641LogicalResult transform::ApplyConversionPatternsOp::verify() {
642 if (getNumRegions() != 1 && getNumRegions() != 2)
643 return emitOpError() << "expected 1 or 2 regions";
644 if (!getPatterns().empty()) {
645 for (Operation &op : getPatterns().front()) {
646 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
648 emitOpError() << "expected pattern children ops to implement "
649 "ConversionPatternDescriptorOpInterface";
650 diag.attachNote(op.getLoc()) << "op without interface";
651 return diag;
652 }
653 }
654 }
655 if (getNumRegions() == 2) {
656 Region &typeConverterRegion = getRegion(1);
 // NOTE(review): front() assumes region 1 contains a block. Confirm that
 // the op's region constraints guarantee this for parsed IR.
657 if (!llvm::hasSingleElement(typeConverterRegion.front()))
658 return emitOpError()
659 << "expected exactly one op in default type converter region";
660 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
661 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
662 maybeTypeConverter);
663 if (!typeConverterOp) {
665 << "expected default converter child op to "
666 "implement TypeConverterBuilderOpInterface";
667 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
668 return diag;
669 }
670 // Check default type converter type.
671 if (!getPatterns().empty()) {
672 for (Operation &op : getPatterns().front()) {
673 auto descriptor =
674 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
675 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
676 return failure();
677 }
678 }
679 }
680 return success();
681}
682
/// Declares effects on the target handle. It is consumed unless handles are
/// preserved; in that case it is only read, and apply() attaches a tracking
/// listener to the conversion.
683void transform::ApplyConversionPatternsOp::getEffects(
685 if (!getPreserveHandles()) {
686 transform::consumesHandle(getTargetMutable(), effects);
687 } else {
688 transform::onlyReadsHandle(getTargetMutable(), effects);
689 }
691}
692
/// Builds the op with `target` as its operand and two regions, in order: the
/// patterns region and the default type converter region. Each region gets a
/// single empty block. The optional body builders populate those blocks.
693void transform::ApplyConversionPatternsOp::build(
695 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
696 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
697 result.addOperands(target);
698
 // Each scope restores the caller's insertion point after building a region.
699 {
700 OpBuilder::InsertionGuard g(builder);
701 Region *region1 = result.addRegion();
702 builder.createBlock(region1);
703 if (patternsBodyBuilder)
704 patternsBodyBuilder(builder, result.location);
705 }
706 {
707 OpBuilder::InsertionGuard g(builder);
708 Region *region2 = result.addRegion();
709 builder.createBlock(region2);
710 if (typeConverterBodyBuilder)
711 typeConverterBodyBuilder(builder, result.location);
712 }
713}
714
715//===----------------------------------------------------------------------===//
716// ApplyToLLVMConversionPatternsOp
717//===----------------------------------------------------------------------===//
718
/// Adds the convert-to-LLVM patterns of the dialect named by getDialectName().
/// That dialect is expected to be loaded and to implement
/// ConvertToLLVMPatternInterface; verify() checks both.
719void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
720 TypeConverter &typeConverter, RewritePatternSet &patterns) {
721 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
722 assert(dialect && "expected that dialect is loaded");
723 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
724 // ConversionTarget is currently ignored because the enclosing
725 // apply_conversion_patterns op sets up its own ConversionTarget.
 // NOTE(review): the static_cast assumes `typeConverter` really is an
 // LLVMTypeConverter. verifyTypeConverter() enforces this for the enclosing
 // op's default converter; confirm that no other converter can reach here.
727 iface->populateConvertToLLVMConversionPatterns(
728 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
729}
730
731LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
732 transform::TypeConverterBuilderOpInterface builder) {
733 if (builder.getTypeConverterType() != "LLVMTypeConverter")
734 return emitOpError("expected LLVMTypeConverter");
735 return success();
736}
737
738LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
739 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
740 if (!dialect)
741 return emitOpError("unknown dialect or dialect not loaded: ")
742 << getDialectName();
743 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
744 if (!iface)
745 return emitOpError(
746 "dialect does not implement ConvertToLLVMPatternInterface or "
747 "extension was not loaded: ")
748 << getDialectName();
749 return success();
750}
751
752//===----------------------------------------------------------------------===//
753// ApplyLoopInvariantCodeMotionOp
754//===----------------------------------------------------------------------===//
755
757transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
758 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
 // Performs loop-invariant code motion on the loop `target`. The note below
 // explains why `rewriter` (and thus handle tracking) is not used.
761 // Currently, LICM does not remove operations, so we don't need tracking.
762 // If this ever changes, add a LICM entry point that takes a rewriter.
765}
766
/// Declares effects: the target handle is only read, not consumed.
767void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
769 transform::onlyReadsHandle(getTargetMutable(), effects);
771}
772
773//===----------------------------------------------------------------------===//
774// ApplyRegisteredPassOp
775//===----------------------------------------------------------------------===//
776
/// Declares effects:
///  - the target handle is consumed;
///  - the dynamic-option params are only read;
///  - the result handle is produced (apply() maps it to the former targets);
///  - the payload IR is modified.
777void transform::ApplyRegisteredPassOp::getEffects(
779 consumesHandle(getTargetMutable(), effects);
780 onlyReadsHandle(getDynamicOptionsMutable(), effects);
781 producesHandle(getOperation()->getOpResults(), effects);
782 modifiesPayload(effects);
783}
784
786transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
789 // Obtain a single options-string to pass to the pass(-pipeline) from options
790 // passed in as a dictionary of keys mapping to values which are either
791 // attributes or param-operands pointing to attributes.
 // The resulting string has the form `k1=v1 k2=v2 ...`. List values (arrays
 // and multi-valued params) are joined with ','.
792 OperandRange dynamicOptions = getDynamicOptions();
793
794 std::string options;
795 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
796
797 // A helper to convert an option's attribute value into a corresponding
798 // string representation, with the ability to obtain the attr(s) from a param.
799 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
800 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
801 // The corresponding value attribute(s) is/are passed in via a param.
802 // Obtain the param-operand via its specified index.
803 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
 // NOTE(review): only the upper bound is asserted. Also, the adjacent
 // string literals below concatenate without a space
 // ("...DictionaryAttrshould...").
804 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
805 "the number of ParamOperandAttrs in the options DictionaryAttr"
806 "should be the same as the number of options passed as params");
807 ArrayRef<Attribute> attrsAssociatedToParam =
808 state.getParams(dynamicOptions[dynamicOptionIdx]);
809 // Recursive so as to append all attrs associated to the param.
810 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
811 ",");
812 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
813 // Recursive so as to append all nested attrs of the array.
814 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
815 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
816 // Convert to unquoted string.
 // NOTE(review): `.str()` builds a temporary std::string; streaming the
 // StringRef directly would avoid the copy.
817 optionsStream << strAttr.getValue().str();
818 } else {
819 // For all other attributes, ask the attr to print itself (without type).
820 valueAttr.print(optionsStream, /*elideType=*/true);
821 }
822 };
823
824 // Convert the options DictionaryAttr into a single string.
825 llvm::interleave(
826 getOptions(), optionsStream,
827 [&](auto namedAttribute) {
828 optionsStream << namedAttribute.getName().str(); // Append the key.
829 optionsStream << "="; // And the key-value separator.
830 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
831 },
832 " ");
833 optionsStream.flush();
834
835 // Get pass or pass pipeline from registry.
 // Registered pipelines are looked up first, then individual passes.
836 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
837 if (!info)
838 info = PassInfo::lookup(getPassName());
839 if (!info)
840 return emitDefiniteFailure()
841 << "unknown pass or pass pipeline: " << getPassName();
842
843 // Create pass manager and add the pass or pass pipeline.
845 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
846 emitError(msg);
847 return failure();
848 }))) {
849 return emitDefiniteFailure()
850 << "failed to add pass or pass pipeline to pipeline: "
851 << getPassName();
852 }
853
 // Materialize the payload list once; the same list becomes this op's
 // result below.
854 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
855 for (Operation *target : targets) {
856 // Make sure that this transform is not applied to itself. Modifying the
857 // transform IR while it is being interpreted is generally dangerous. Even
858 // more so when applying passes because they may perform a wide range of IR
859 // modifications.
860 DiagnosedSilenceableFailure payloadCheck =
862 if (!payloadCheck.succeeded())
863 return payloadCheck;
864
865 // Run the pass or pass pipeline on the current target operation.
866 if (failed(pm.run(target))) {
867 auto diag = emitSilenceableError() << "pass pipeline failed";
868 diag.attachNote(target->getLoc()) << "target op";
869 return diag;
870 }
871 }
872
873 // The applied pass will have directly modified the payload IR(s).
874 results.set(llvm::cast<OpResult>(getResult()), targets);
876}
877
879 OpAsmParser &parser, DictionaryAttr &options,
881 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
882 SmallVector<NamedAttribute> keyValuePairs;
 // Index assigned to the next param operand. It equals that operand's
 // position in `dynamicOptions` (assuming `dynamicOptions` starts empty).
883 size_t dynamicOptionsIdx = 0;
884
885 // Helper for allowing parsing of option values which can be of the form:
886 // - a normal attribute
887 // - an operand (which would be converted to an attr referring to the operand)
888 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
889 std::function<ParseResult(Attribute &)> parseValue =
890 [&](Attribute &valueAttr) -> ParseResult {
891 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
892 if (succeeded(parser.parseOptionalLSquare())) {
894
895 // Recursively parse the array's elements, which might be operands.
896 if (parser.parseCommaSeparatedList(
898 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
899 " in options dictionary") ||
900 parser.parseRSquare())
901 return failure(); // NB: Attempted parse should've output error message.
902
903 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
904
905 return success();
906 }
907
908 // Parse the value, which can be either an attribute or an operand.
909 OptionalParseResult parsedValueAttr =
910 parser.parseOptionalAttribute(valueAttr);
911 if (!parsedValueAttr.has_value()) {
913 ParseResult parsedOperand = parser.parseOperand(operand);
914 if (failed(parsedOperand))
915 return failure(); // NB: Attempted parse should've output error message.
916 // To make use of the operand, we need to store it in the options dict.
917 // As SSA-values cannot occur in attributes, what we do instead is store
918 // an attribute in its place that contains the index of the param-operand,
919 // so that an attr-value associated to the param can be resolved later on.
920 dynamicOptions.push_back(operand);
921 auto wrappedIndex = IntegerAttr::get(
922 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
923 valueAttr =
924 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
925 } else if (failed(parsedValueAttr.value())) {
926 return failure(); // NB: Attempted parse should have output error message.
927 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
 // Users may not spell the marker attribute directly in the custom syntax.
928 return parser.emitError(parser.getCurrentLocation())
929 << "the param_operand attribute is a marker reserved for "
930 << "indicating a value will be passed via params and is only used "
931 << "in the generic print format";
932 }
933
934 return success();
935 };
936
937 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
938 // string and `value` looks like either an attribute or an operand-in-an-attr.
939 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
940 std::string key;
941 Attribute valueAttr;
942
 // An empty string key is rejected the same way as a missing key.
943 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
944 return parser.emitError(parser.getCurrentLocation())
945 << "expected key to either be an identifier or a string";
946
947 if (failed(parser.parseEqual()))
948 return parser.emitError(parser.getCurrentLocation())
949 << "expected '=' after key in key-value pair";
950
951 if (failed(parseValue(valueAttr)))
952 return parser.emitError(parser.getCurrentLocation())
953 << "expected a valid attribute or operand as value associated "
954 << "to key '" << key << "'";
955
956 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
957
958 return success();
959 };
960
963 " in options dictionary"))
964 return failure(); // NB: Attempted parse should have output error message.
965
 // Reject repeated keys. findDuplicate() also sorts keyValuePairs in place
 // (see the inline comment), which is what makes getWithSorted() valid below.
966 if (DictionaryAttr::findDuplicate(
967 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
968 .has_value())
969 return parser.emitError(parser.getCurrentLocation())
970 << "duplicate keys found in options dictionary";
971
972 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
973
974 return success();
975}
976
978 Operation *op,
979 DictionaryAttr options,
980 ValueRange dynamicOptions) {
// Prints `options` in the `{key = value, ...}` form; counterpart of
// parseApplyRegisteredPassOptions. An empty dictionary prints nothing at all.
981 if (options.empty())
982 return;
983
// Prints a single option value. `param_operand` attributes are printed as the
// SSA operand they refer to; arrays are printed element-wise (recursively) so
// that operands nested inside arrays are pretty-printed as well.
984 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
985 if (auto paramOperandAttr =
986 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
987 // Resolve index of param-operand to its actual SSA-value and print that.
// The index is not bounds-checked here; ApplyRegisteredPassOp::verify
// rejects out-of-bounds indices.
988 printer.printOperand(
989 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
990 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
991 // This case is so that ArrayAttr-contained operands are pretty-printed.
992 printer << "[";
993 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
994 printer << "]";
995 } else {
996 printer.printAttribute(valueAttr);
997 }
998 };
999
// Keys are printed via their StringAttr name, followed by ` = <value>`.
1000 printer << "{";
1001 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
1002 printer << namedAttribute.getName();
1003 printer << " = ";
1004 printOptionValue(namedAttribute.getValue());
1005 });
1006 printer << "}";
1007}
1008
1009LogicalResult transform::ApplyRegisteredPassOp::verify() {
1010 // Check that there is a one-to-one correspondence between param operands
1011 // and references to dynamic options in the options dictionary.
1012
1013 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1014
1015 // Helper for option values to mark seen operands as having been seen (once).
1016 std::function<LogicalResult(Attribute)> checkOptionValue =
1017 [&](Attribute valueAttr) -> LogicalResult {
1018 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1019 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1020 if (dynamicOptionIdx < 0 ||
1021 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1022 return emitOpError()
1023 << "dynamic option index " << dynamicOptionIdx
1024 << " is out of bounds for the number of dynamic options: "
1025 << dynamicOptions.size();
1026 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1027 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1028 << " is already used in options";
1029 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1030 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1031 // Recurse into ArrayAttrs as they may contain references to operands.
1032 for (auto eltAttr : arrayAttr)
1033 if (failed(checkOptionValue(eltAttr)))
1034 return failure();
1035 }
1036 return success();
1037 };
1038
1039 for (NamedAttribute namedAttr : getOptions())
1040 if (failed(checkOptionValue(namedAttr.getValue())))
1041 return failure();
1042
1043 // All dynamicOptions-params seen in the dict will have been set to null.
1044 for (Value dynamicOption : dynamicOptions)
1045 if (dynamicOption)
1046 return emitOpError() << "a param operand does not have a corresponding "
1047 << "param_operand attr in the options dict";
1048
1049 return success();
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// CastOp
1054//===----------------------------------------------------------------------===//
1055
/// Associates each target payload op, unchanged, with the result handle; only
/// the handle type differs between the input and the result.
1057transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1058 Operation *target, ApplyToEachResultList &results,
1060 results.push_back(target);
1062}
1063
/// The cast only reads the payload and its input handle (which is not
/// consumed), and produces the result handle.
1064void transform::CastOp::getEffects(
1066 onlyReadsPayload(effects);
1067 onlyReadsHandle(getInputMutable(), effects);
1068 producesHandle(getOperation()->getOpResults(), effects);
1069}
1070
1071bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1072 assert(inputs.size() == 1 && "expected one input");
1073 assert(outputs.size() == 1 && "expected one output");
1074 return llvm::all_of(
1075 std::initializer_list<Type>{inputs.front(), outputs.front()},
1076 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1077}
1078
1079//===----------------------------------------------------------------------===//
1080// Matching helpers (shared by CollectMatchingOp and ForeachMatchOp)
1081//===----------------------------------------------------------------------===//
1082
1083/// Applies matcher operations from the given `block` using
1084/// `blockArgumentMapping` to initialize block arguments. Updates `state`
1085/// accordingly. Every non-terminator op must implement MatchOpInterface;
1086/// otherwise a definite failure is emitted at that op's location. The first
1087/// matcher that does not succeed (silenceable or definite failure) stops the
1088/// matching and its result is returned as-is; logging or silencing a
1089/// silenceable failure is left to the caller. If all matchers succeed,
1090/// populates `mappings` with the payload entities associated with the block
/// terminator operands. Note that `mappings` is cleared before that.
1093 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1096 assert(block.getParent() && "cannot match using a detached block");
1097 auto matchScope = state.make_region_scope(*block.getParent());
1098 if (failed(
1099 state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1101
1102 for (Operation &match : block.without_terminator()) {
1103 if (!isa<transform::MatchOpInterface>(match)) {
1104 return emitDefiniteFailure(match.getLoc())
1105 << "expected operations in the match part to "
1106 "implement MatchOpInterface";
1107 }
1109 state.applyTransform(cast<transform::TransformOpInterface>(match));
1110 if (diag.succeeded())
1111 continue;
1112
// Propagate the first non-successful result unchanged to the caller.
1113 return diag;
1114 }
1115
1116 // Remember the values mapped to the terminator operands so we can
1117 // forward them to the action.
1118 ValueRange yieldedValues = block.getTerminator()->getOperands();
1119 // Our contract with the caller is that the mappings will contain only the
1120 // newly mapped values, clear the rest.
1121 mappings.clear();
1122 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1124}
1125
1126/// Returns `true` if both types implement one of the interfaces provided as
1127/// template parameters.
1128template <typename... Tys>
1129static bool implementSameInterface(Type t1, Type t2) {
1130 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1131}
1132
1133/// Returns `true` if both types implement the same one of the transform dialect
1134/// type interfaces: operation handle, param, or value handle.
/// Used to check that values forwarded between matchers, actions and op
/// results agree on the kind of transform value they carry.
1136 return implementSameInterface<transform::TransformHandleTypeInterface,
1137 transform::TransformParamTypeInterface,
1138 transform::TransformValueHandleTypeInterface>(
1139 t1, t2);
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// CollectMatchingOp
1144//===----------------------------------------------------------------------===//
1145
1147transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1151 getOperation(), getMatcher());
1152 if (matcher.isExternal()) {
1153 return emitDefiniteFailure()
1154 << "unresolved external symbol " << getMatcher();
1155 }
1156
1158 rawResults.resize(getOperation()->getNumResults());
1159 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1160 for (Operation *root : state.getPayloadOps(getRoot())) {
1161 WalkResult walkResult = root->walk([&](Operation *op) {
1162 LDBG(DEBUG_TYPE_MATCHER, 1)
1163 << "matching "
1164 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1165 << " @" << op;
1166
1167 // Try matching.
1169 SmallVector<transform::MappedValue> inputMapping({op});
1171 matcher.getFunctionBody().front(),
1172 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1173 mappings);
1174 if (diag.isDefiniteFailure())
1175 return WalkResult::interrupt();
1176 if (diag.isSilenceableFailure()) {
1177 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1178 << " failed: " << diag.getMessage();
1179 return WalkResult::advance();
1180 }
1181
1182 // If succeeded, collect results.
1183 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1184 if (mapping.size() != 1) {
1185 maybeFailure.emplace(emitSilenceableError()
1186 << "result #" << i << ", associated with "
1187 << mapping.size()
1188 << " payload objects, expected 1");
1189 return WalkResult::interrupt();
1190 }
1191 rawResults[i].push_back(mapping[0]);
1192 }
1193 return WalkResult::advance();
1194 });
1195 if (walkResult.wasInterrupted())
1196 return std::move(*maybeFailure);
1197 assert(!maybeFailure && "failure set but the walk was not interrupted");
1198
1199 for (auto &&[opResult, rawResult] :
1200 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1201 results.setMappedValues(opResult, rawResult);
1202 }
1203 }
1205}
1206
/// Reads (without consuming) the root handle, produces all result handles, and
/// only reads the payload: collecting matches does not modify the IR.
1207void transform::CollectMatchingOp::getEffects(
1209 onlyReadsHandle(getRootMutable(), effects);
1210 producesHandle(getOperation()->getOpResults(), effects);
1211 onlyReadsPayload(effects);
1212}
1213
1214LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1215 SymbolTableCollection &symbolTable) {
1216 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1217 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1218 if (!matcherSymbol ||
1219 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1220 return emitError() << "unresolved matcher symbol " << getMatcher();
1221
1222 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1223 if (argumentTypes.size() != 1 ||
1224 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1225 return emitError()
1226 << "expected the matcher to take one operation handle argument";
1227 }
1228 if (!matcherSymbol.getArgAttr(
1229 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1230 return emitError() << "expected the matcher argument to be marked readonly";
1231 }
1232
1233 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1234 if (resultTypes.size() != getOperation()->getNumResults()) {
1235 return emitError()
1236 << "expected the matcher to yield as many values as op has results ("
1237 << getOperation()->getNumResults() << "), got "
1238 << resultTypes.size();
1239 }
1240
1241 for (auto &&[i, matcherType, resultType] :
1242 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1243 if (implementSameTransformInterface(matcherType, resultType))
1244 continue;
1245
1246 return emitError()
1247 << "mismatching type interfaces for matcher result and op result #"
1248 << i;
1249 }
1250
1251 return success();
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// ForeachMatchOp
1256//===----------------------------------------------------------------------===//
1257
1258// This is fine because nothing is actually consumed by this op.
1259bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1260
/// For each payload op nested in the payload of `root` (the root op itself is
/// only visited when restrict_root is set), tries the match/action pairs in
/// order and applies the action of the first matcher that succeeds. Silenceable
/// action failures are collected into one diagnostic returned at the end;
/// definite failures abort the walk. Action results are accumulated into the
/// forwarded outputs.
1262transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1266 matchActionPairs;
1267 matchActionPairs.reserve(getMatchers().size());
1268 SymbolTableCollection symbolTable;
// Resolve all matcher/action symbols up front.
1269 for (auto &&[matcher, action] :
1270 llvm::zip_equal(getMatchers(), getActions())) {
1271 auto matcherSymbol =
1272 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1273 getOperation(), cast<SymbolRefAttr>(matcher));
1274 auto actionSymbol =
1275 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1276 getOperation(), cast<SymbolRefAttr>(action));
1277 assert(matcherSymbol && actionSymbol &&
1278 "unresolved symbols not caught by the verifier");
1279
1280 if (matcherSymbol.isExternal())
1281 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1282 if (actionSymbol.isExternal())
1283 return emitDefiniteFailure() << "unresolved external symbol " << action;
1284
1285 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1286 }
1287
1288 DiagnosedSilenceableFailure overallDiag =
1290
1291 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1292 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1293 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1294 // Explicitly add the mapping for the first block argument (the op being
1295 // matched).
1296 matchInputMapping.emplace_back();
1298 getForwardedInputs(), state);
// Slot 0 of the matcher inputs holds the currently visited op; it is refilled
// for every op visited by the walk below.
1299 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1300 actionResultMapping.resize(getForwardedOutputs().size());
1301
1302 for (Operation *root : state.getPayloadOps(getRoot())) {
1303 WalkResult walkResult = root->walk([&](Operation *op) {
1304 // If getRestrictRoot is not present, skip over the root op itself so we
1305 // don't invalidate it.
1306 if (!getRestrictRoot() && op == root)
1307 return WalkResult::advance();
1308
1309 LDBG(DEBUG_TYPE_MATCHER, 1)
1310 << "matching "
1311 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1312 << " @" << op;
1313
1314 firstMatchArgument.clear();
1315 firstMatchArgument.push_back(op);
1316
1317 // Try all the match/action pairs until the first successful match.
1318 for (auto [matcher, action] : matchActionPairs) {
1320 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1321 state, matchOutputMapping);
1322 if (diag.isDefiniteFailure())
1323 return WalkResult::interrupt();
// A silenceable matcher failure only means "no match": try the next pair.
1324 if (diag.isSilenceableFailure()) {
1325 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1326 << " failed: " << diag.getMessage();
1327 continue;
1328 }
1329
// Bind the matcher's yielded payloads to the action's block arguments.
1330 auto scope = state.make_region_scope(action.getFunctionBody());
1331 if (failed(state.mapBlockArguments(
1332 action.getFunctionBody().front().getArguments(),
1333 matchOutputMapping))) {
1334 return WalkResult::interrupt();
1335 }
1336
1337 for (Operation &transform :
1338 action.getFunctionBody().front().without_terminator()) {
1340 state.applyTransform(cast<TransformOpInterface>(transform));
1341 if (result.isDefiniteFailure())
1342 return WalkResult::interrupt();
// Record the silenceable failure (with notes pointing at the action and the
// payload op), silence it, and keep applying this action's remaining ops.
1343 if (result.isSilenceableFailure()) {
1344 if (overallDiag.succeeded()) {
1345 overallDiag = emitSilenceableError() << "actions failed";
1346 }
1347 overallDiag.attachNote(action->getLoc())
1348 << "failed action: " << result.getMessage();
1349 overallDiag.attachNote(op->getLoc())
1350 << "when applied to this matching payload";
1351 (void)result.silence();
1352 continue;
1353 }
1354 }
// Append this action's yielded payloads to the forwarded outputs; this fails
// when a yielded handle maps to several entities without flatten_results.
1355 if (failed(detail::appendValueMappings(
1356 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1357 action.getFunctionBody().front().getTerminator()->getOperands(),
1358 state, getFlattenResults()))) {
1360 << "action @" << action.getName()
1361 << " has results associated with multiple payload entities, "
1362 "but flattening was not requested";
1363 return WalkResult::interrupt();
1364 }
// Only the first successful match/action pair applies to this op.
1365 break;
1366 }
1367 return WalkResult::advance();
1368 });
1369 if (walkResult.wasInterrupted())
1371 }
1372
1373 // The root operation should not have been affected, so we can just reassign
1374 // the payload to the result. Note that we need to consume the root handle to
1375 // make sure any handles to operations inside, that could have been affected
1376 // by actions, are invalidated.
1377 results.set(llvm::cast<OpResult>(getUpdated()),
1378 state.getPayloadOps(getRoot()));
1379 for (auto &&[result, mapping] :
1380 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1381 results.setMappedValues(result, mapping);
1382 }
1383 return overallDiag;
1384}
1385
1386void transform::ForeachMatchOp::getAsmResultNames(
1387 OpAsmSetValueNameFn setNameFn) {
1388 setNameFn(getUpdated(), "updated_root");
1389 for (Value v : getForwardedOutputs()) {
1390 setNameFn(v, "yielded");
1391 }
1392}
1393
/// Consumes the root handle, only reads the forwarded inputs, produces all
/// results, and conservatively declares that the payload is modified since the
/// actions may apply arbitrary transforms.
1394void transform::ForeachMatchOp::getEffects(
1396 // Bail if invalid.
// Without a root operand or an `updated` result (e.g. when queried on a
// malformed op), only the conservative payload effect is reported.
1397 if (getOperation()->getNumOperands() < 1 ||
1398 getOperation()->getNumResults() < 1) {
1399 return modifiesPayload(effects);
1400 }
1401
1402 consumesHandle(getRootMutable(), effects);
1403 onlyReadsHandle(getForwardedInputsMutable(), effects);
1404 producesHandle(getOperation()->getOpResults(), effects);
1405 modifiesPayload(effects);
1406}
1407
1408/// Parses the comma-separated list of symbol reference pairs of the format
1409/// `@matcher -> @action`.
/// At least one pair is required. On success, `matchers` and `actions` hold
/// co-indexed arrays of flat SymbolRefAttrs.
1410static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1412 ArrayAttr &actions) {
1413 StringAttr matcher;
1414 StringAttr action;
1415 SmallVector<Attribute> matcherList;
1416 SmallVector<Attribute> actionList;
1417 do {
// Only single-component (flat) symbol names are accepted here.
1418 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1419 parser.parseSymbolName(action)) {
1420 return failure();
1421 }
1422 matcherList.push_back(SymbolRefAttr::get(matcher));
1423 actionList.push_back(SymbolRefAttr::get(action));
1424 } while (parser.parseOptionalComma().succeeded());
1425
1426 matchers = parser.getBuilder().getArrayAttr(matcherList);
1427 actions = parser.getBuilder().getArrayAttr(actionList);
1428 return success();
1429}
1430
1431/// Prints the comma-separated list of symbol reference pairs of the format
1432/// `@matcher -> @action`.
/// Each pair goes on its own line, indented two levels deeper than the
/// current indentation.
1434 ArrayAttr matchers, ArrayAttr actions) {
1435 printer.increaseIndent();
1436 printer.increaseIndent();
1437 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1438 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1439 printer.printNewline();
1440 printer << cast<SymbolRefAttr>(matcher) << " -> "
1441 << cast<SymbolRefAttr>(action);
// Separator after every pair except the last one.
1442 if (idx != matchers.size() - 1)
1443 printer << ", ";
1444 }
1445 printer.decreaseIndent();
1446 printer.decreaseIndent();
1447}
1448
/// Checks that matchers and actions come in non-empty, co-indexed pairs.
/// Duplicate matchers only produce a warning (verification still succeeds),
/// because apply stops at the first successful pair, so later pairs reusing
/// the same matcher can never fire.
1449LogicalResult transform::ForeachMatchOp::verify() {
1450 if (getMatchers().size() != getActions().size())
1451 return emitOpError() << "expected the same number of matchers and actions";
1452 if (getMatchers().empty())
1453 return emitOpError() << "expected at least one match/action pair";
1454
1456 for (Attribute name : getMatchers()) {
1457 if (matcherNames.insert(name).second)
1458 continue;
1459 emitWarning() << "matcher " << name
1460 << " is used more than once, only the first match will apply";
1461 }
1462
1463 return success();
1464}
1465
1466/// Checks that the attributes of the function-like operation have correct
1467/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1468/// annotations being present even if they can be inferred from the body.
/// If `emitWarnings`, additionally warns about arguments marked as consumed
/// that the body does not actually consume. Errors are returned as silenceable
/// failures on the op.
1470verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1471 bool alsoVerifyInternal = false) {
1472 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
// Arguments actually consumed by the body (only computable with a body).
1473 llvm::SmallDenseSet<unsigned> consumedArguments;
1474 if (!op.isExternal()) {
1475 transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1476 consumedArguments);
1477 }
1478 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1479 bool isConsumed =
1480 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1481 nullptr;
1482 bool isReadOnly =
1483 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1484 nullptr;
1485 if (isConsumed && isReadOnly) {
1486 return transformOp.emitSilenceableError()
1487 << "argument #" << i << " cannot be both readonly and consumed";
1488 }
1489 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1490 return transformOp.emitSilenceableError()
1491 << "must provide consumed/readonly status for arguments of "
1492 "external or called ops";
1493 }
// Without a body there is nothing to compare the annotations against.
1494 if (op.isExternal())
1495 continue;
1496
// `!isConsumed` is implied by `isReadOnly` here (both-set returned above), so
// this fires when a readonly-marked argument is consumed in the body.
// Unannotated arguments are accepted when annotations are not required.
1497 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1498 return transformOp.emitSilenceableError()
1499 << "argument #" << i
1500 << " is consumed in the body but is not marked as such";
1501 }
1502 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1503 // Cannot use op.emitWarning() here as it would attempt to verify the op
1504 // before printing, resulting in infinite recursion.
1505 emitWarning(op->getLoc())
1506 << "op argument #" << i
1507 << " is not consumed in the body but is marked as consumed";
1508 }
1509 }
1511}
1512
1513LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1514 SymbolTableCollection &symbolTable) {
1515 assert(getMatchers().size() == getActions().size());
1516 auto consumedAttr =
1517 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1518 for (auto &&[matcher, action] :
1519 llvm::zip_equal(getMatchers(), getActions())) {
1520 // Presence and typing.
1521 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1522 symbolTable.lookupNearestSymbolFrom(getOperation(),
1523 cast<SymbolRefAttr>(matcher)));
1524 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1525 symbolTable.lookupNearestSymbolFrom(getOperation(),
1526 cast<SymbolRefAttr>(action)));
1527 if (!matcherSymbol ||
1528 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1529 return emitError() << "unresolved matcher symbol " << matcher;
1530 if (!actionSymbol ||
1531 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1532 return emitError() << "unresolved action symbol " << action;
1533
1535 /*emitWarnings=*/false,
1536 /*alsoVerifyInternal=*/true)
1537 .checkAndReport())) {
1538 return failure();
1539 }
1541 /*emitWarnings=*/false,
1542 /*alsoVerifyInternal=*/true)
1543 .checkAndReport())) {
1544 return failure();
1545 }
1546
1547 // Input -> matcher forwarding.
1548 TypeRange operandTypes = getOperandTypes();
1549 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1550 if (operandTypes.size() != matcherArguments.size()) {
1552 emitError() << "the number of operands (" << operandTypes.size()
1553 << ") doesn't match the number of matcher arguments ("
1554 << matcherArguments.size() << ") for " << matcher;
1555 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1556 return diag;
1557 }
1558 for (auto &&[i, operand, argument] :
1559 llvm::enumerate(operandTypes, matcherArguments)) {
1560 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1562 emitOpError()
1563 << "does not expect matcher symbol to consume its operand #" << i;
1564 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1565 return diag;
1566 }
1567
1568 if (implementSameTransformInterface(operand, argument))
1569 continue;
1570
1572 emitError()
1573 << "mismatching type interfaces for operand and matcher argument #"
1574 << i << " of matcher " << matcher;
1575 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1576 return diag;
1577 }
1578
1579 // Matcher -> action forwarding.
1580 TypeRange matcherResults = matcherSymbol.getResultTypes();
1581 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1582 if (matcherResults.size() != actionArguments.size()) {
1583 return emitError() << "mismatching number of matcher results and "
1584 "action arguments between "
1585 << matcher << " (" << matcherResults.size() << ") and "
1586 << action << " (" << actionArguments.size() << ")";
1587 }
1588 for (auto &&[i, matcherType, actionType] :
1589 llvm::enumerate(matcherResults, actionArguments)) {
1590 if (implementSameTransformInterface(matcherType, actionType))
1591 continue;
1592
1593 return emitError() << "mismatching type interfaces for matcher result "
1594 "and action argument #"
1595 << i << "of matcher " << matcher << " and action "
1596 << action;
1597 }
1598
1599 // Action -> result forwarding.
1600 TypeRange actionResults = actionSymbol.getResultTypes();
1601 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1602 if (actionResults.size() != resultTypes.size()) {
1604 emitError() << "the number of action results ("
1605 << actionResults.size() << ") for " << action
1606 << " doesn't match the number of extra op results ("
1607 << resultTypes.size() << ")";
1608 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1609 return diag;
1610 }
1611 for (auto &&[i, resultType, actionType] :
1612 llvm::enumerate(resultTypes, actionResults)) {
1613 if (implementSameTransformInterface(resultType, actionType))
1614 continue;
1615
1617 emitError() << "mismatching type interfaces for action result #" << i
1618 << " of action " << action << " and op result";
1619 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1620 return diag;
1621 }
1622 }
1623 return success();
1624}
1625
1626//===----------------------------------------------------------------------===//
1627// ForeachOp
1628//===----------------------------------------------------------------------===//
1629
1631transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1634 // We store the payloads before executing the body as ops may be removed from
1635 // the mapping by the TrackingRewriter while iteration is in progress.
1637 detail::prepareValueMappings(payloads, getTargets(), state);
1638 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1639 bool withZipShortest = getWithZipShortest();
1640
1641 // In case of `zip_shortest`, set the number of iterations to the
1642 // smallest payload in the targets.
1643 if (withZipShortest) {
1644 numIterations =
1645 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1646 const SmallVector<MappedValue> &b) {
1647 return a.size() < b.size();
1648 })->size();
1649
1650 for (auto &payload : payloads)
1651 payload.resize(numIterations);
1652 }
1653
1654 // As we will be "zipping" over them, check all payloads have the same size.
1655 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1656 // when true.
1657 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1658 argIdx++) {
1659 if (payloads[argIdx].size() != numIterations) {
1660 return emitSilenceableError()
1661 << "prior targets' payload size (" << numIterations
1662 << ") differs from payload size (" << payloads[argIdx].size()
1663 << ") of target " << getTargets()[argIdx];
1664 }
1665 }
1666
1667 // Start iterating, indexing into payloads to obtain the right arguments to
1668 // call the body with - each slice of payloads at the same argument index
1669 // corresponding to a tuple to use as the body's block arguments.
1670 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1671 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1672 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1673 auto scope = state.make_region_scope(getBody());
1674 // Set up arguments to the region's block.
1675 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1676 MappedValue argument = payloads[argIdx][iterIdx];
1677 // Note that each blockArg's handle gets associated with just a single
1678 // element from the corresponding target's payload.
1679 if (failed(state.mapBlockArgument(blockArg, {argument})))
1681 }
1682
1683 // Execute loop body.
1684 for (Operation &transform : getBody().front().without_terminator()) {
1686 llvm::cast<transform::TransformOpInterface>(transform));
1687 if (!result.succeeded())
1688 return result;
1689 }
1690
1691 // Append yielded payloads to corresponding results from prior iterations.
1692 OperandRange yieldOperands = getYieldOp().getOperands();
1693 for (auto &&[result, yieldOperand, resTuple] :
1694 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1695 // NB: each iteration we add any number of ops/vals/params to a result.
1696 if (isa<TransformHandleTypeInterface>(result.getType()))
1697 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1698 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1699 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1700 else if (isa<TransformParamTypeInterface>(result.getType()))
1701 llvm::append_range(resTuple, state.getParams(yieldOperand));
1702 else
1703 assert(false && "unhandled handle type");
1704 }
1705
1706 // Associate the accumulated result payloads to the op's actual results.
1707 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1708 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1709
1711}
1712
/// Derives the op's effects from its body:
///  - a target handle is consumed if some body op consumes the co-indexed
///    block argument, otherwise it is only read;
///  - the payload effect is the strongest one among body ops (modify, then
///    read, otherwise none);
///  - all results are produced.
1713void transform::ForeachOp::getEffects(
1715 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1716 // arity errors, this method might get called before/in absence of `verify()`.
1717 for (auto &&[target, blockArg] :
1718 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
// Copy into a named local: C++17 lambdas cannot capture structured bindings.
1719 BlockArgument blockArgument = blockArg;
1720 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1721 return isHandleConsumed(blockArgument,
1722 cast<TransformOpInterface>(&op));
1723 })) {
1724 consumesHandle(target, effects);
1725 } else {
1726 onlyReadsHandle(target, effects);
1727 }
1728 }
1729
1730 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1731 return doesModifyPayload(cast<TransformOpInterface>(&op));
1732 })) {
1733 modifiesPayload(effects);
1734 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1735 return doesReadPayload(cast<TransformOpInterface>(&op));
1736 })) {
1737 onlyReadsPayload(effects);
1738 }
1739
1740 producesHandle(getOperation()->getOpResults(), effects);
1741}
1742
/// From the parent op, control enters the body. From the body's terminator,
/// control either re-enters the body (next iteration) or exits to the parent.
1743void transform::ForeachOp::getSuccessorRegions(
1745 Region *bodyRegion = &getBody();
1746 if (point.isParent()) {
1747 regions.emplace_back(bodyRegion);
1748 return;
1749 }
1750
1751 // Branch back to the region or the parent.
1752 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
1753 &getBody() &&
1754 "unexpected region index");
1755 regions.emplace_back(bodyRegion);
1756 regions.push_back(RegionSuccessor::parent());
1757}
1758
1759ValueRange transform::ForeachOp::getSuccessorInputs(RegionSuccessor successor) {
1760 return successor.isParent() ? ValueRange(getResults())
1761 : ValueRange(getBody().getArguments());
1762}
1763
/// On entry, all operands of the op are forwarded to the body region.
/// NOTE(review): this presumes the op's operands are exactly its `targets`.
1765transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1766 // Each block argument handle is mapped to a subset (a single payload entity:
1767 // an op, value or param) of the payload of the corresponding `targets` operand.
1768 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1769 return getOperation()->getOperands();
1770}
1771
1772transform::YieldOp transform::ForeachOp::getYieldOp() {
1773 return cast<transform::YieldOp>(getBody().front().getTerminator());
1774}
1775
1776LogicalResult transform::ForeachOp::verify() {
1777 for (auto [targetOpt, bodyArgOpt] :
1778 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1779 if (!targetOpt || !bodyArgOpt)
1780 return emitOpError() << "expects the same number of targets as the body "
1781 "has block arguments";
1782 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1783 return emitOpError(
1784 "expects co-indexed targets and the body's "
1785 "block arguments to have the same op/value/param type");
1786 }
1787
1788 for (auto [resultOpt, yieldOperandOpt] :
1789 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1790 if (!resultOpt || !yieldOperandOpt)
1791 return emitOpError() << "expects the same number of results as the "
1792 "yield terminator has operands";
1793 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1794 return emitOpError("expects co-indexed results and yield "
1795 "operands to have the same op/value/param type");
1796 }
1797
1798 return success();
1799}
1800
1801//===----------------------------------------------------------------------===//
1802// GetParentOp
1803//===----------------------------------------------------------------------===//
1804
/// For every payload op of `getTarget()`, climbs `getNthParent()` ancestors.
/// Only ancestors passing the optional isolated-from-above and op-name filters
/// are counted. The ancestors found are associated with the result, in target
/// order, and are deduplicated when `getDeduplicate()` is set. If the count is
/// zero, the loop never runs and the target itself is collected.
1806transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1810 DenseSet<Operation *> resultSet;
1811 for (Operation *target : state.getPayloadOps(getTarget())) {
1812 Operation *parent = target;
1813 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
  // Step to the next ancestor, then keep climbing until one passes all filters.
1814 parent = parent->getParentOp();
1815 while (parent) {
1816 bool checkIsolatedFromAbove =
1817 !getIsolatedFromAbove() ||
1819 bool checkOpName = !getOpName().has_value() ||
1820 parent->getName().getStringRef() == *getOpName();
1821 if (checkIsolatedFromAbove && checkOpName)
1822 break;
1823 parent = parent->getParentOp();
1824 }
  // Ran out of ancestors before finding a matching one.
1825 if (!parent) {
1826 if (getAllowEmptyResults()) {
  // NOTE(review): this maps the result to the parents already collected
  // for *earlier* targets, so the result is not necessarily empty. Confirm
  // that partial results are intended here.
1827 results.set(llvm::cast<OpResult>(getResult()), parents);
1829 }
1831 emitSilenceableError()
1832 << "could not find a parent op that matches all requirements";
1833 diag.attachNote(target->getLoc()) << "target op";
1834 return diag;
1835 }
1836 }
  // `resultSet` tracks parents already emitted so each appears once.
1837 if (getDeduplicate()) {
1838 if (resultSet.insert(parent).second)
1839 parents.push_back(parent);
1840 } else {
1841 parents.push_back(parent);
1842 }
1843 }
1844 results.set(llvm::cast<OpResult>(getResult()), parents);
1846}
1847
1848//===----------------------------------------------------------------------===//
1849// GetConsumersOfResult
1850//===----------------------------------------------------------------------===//
1851
/// Associates the result with the users of result `getResultNumber()` of the
/// single payload op of `getTarget()`. An empty target handle yields an empty
/// result. More than one payload op, or an out-of-range result number, is a
/// definite failure.
1853transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1856 int64_t resultNumber = getResultNumber();
1857 auto payloadOps = state.getPayloadOps(getTarget());
1858 if (std::empty(payloadOps)) {
1859 results.set(cast<OpResult>(getResult()), {});
1861 }
1862 if (!llvm::hasSingleElement(payloadOps))
1863 return emitDefiniteFailure()
1864 << "handle must be mapped to exactly one payload op";
1865
1866 Operation *target = *payloadOps.begin();
  // NOTE(review): `resultNumber` is signed and the bound is unsigned. A
  // negative value would pass this check and reach getResult(). Confirm the
  // attribute is constrained to be non-negative.
1867 if (target->getNumResults() <= resultNumber)
1868 return emitDefiniteFailure() << "result number overflow";
  // NOTE(review): getUsers() may list an op once per use, so a consumer that
  // uses this result several times may appear several times in the handle.
1869 results.set(llvm::cast<OpResult>(getResult()),
1870 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1872}
1873
1874//===----------------------------------------------------------------------===//
1875// GetDefiningOp
1876//===----------------------------------------------------------------------===//
1877
/// Associates the result with the defining op of each payload value of
/// `getTarget()`, in order and without deduplication. A block argument has no
/// defining op and produces a silenceable failure with a note at that value.
1879transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1882 SmallVector<Operation *> definingOps;
1883 for (Value v : state.getPayloadValues(getTarget())) {
1884 if (llvm::isa<BlockArgument>(v)) {
1886 emitSilenceableError() << "cannot get defining op of block argument";
1887 diag.attachNote(v.getLoc()) << "target value";
1888 return diag;
1889 }
1890 definingOps.push_back(v.getDefiningOp());
1891 }
1892 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1894}
1895
1896//===----------------------------------------------------------------------===//
1897// GetProducerOfOperand
1898//===----------------------------------------------------------------------===//
1899
/// Associates the result with the op defining operand `getOperandNumber()` of
/// each payload op of `getTarget()`, in order. A silenceable failure occurs if
/// a payload op has too few operands, or if that operand has no defining op
/// (e.g. it is a block argument).
1901transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1904 int64_t operandNumber = getOperandNumber();
1905 SmallVector<Operation *> producers;
1906 for (Operation *target : state.getPayloadOps(getTarget())) {
  // NOTE(review): as in GetConsumersOfResult, a negative operand number would
  // pass this signed/unsigned bound check. Confirm it is constrained.
1907 Operation *producer =
1908 target->getNumOperands() <= operandNumber
1909 ? nullptr
1910 : target->getOperand(operandNumber).getDefiningOp();
1911 if (!producer) {
1913 emitSilenceableError()
1914 << "could not find a producer for operand number: " << operandNumber
1915 << " of " << *target;
1916 diag.attachNote(target->getLoc()) << "target op";
1917 return diag;
1918 }
1919 producers.push_back(producer);
1920 }
1921 results.set(llvm::cast<OpResult>(getResult()), producers);
1923}
1924
1925//===----------------------------------------------------------------------===//
1926// GetOperandOp
1927//===----------------------------------------------------------------------===//
1928
/// Associates the result with the operands of each payload op of `getTarget()`.
/// Operands are selected by the position specification (`getIsAll()`,
/// `getIsInverted()`, `getRawPositionList()`) and concatenated in payload order.
1930transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1933 SmallVector<Value> operands;
1934 for (Operation *target : state.getPayloadOps(getTarget())) {
  // Resolve the position specification against this op's operand count.
1935 SmallVector<int64_t> operandPositions;
1937 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1938 target->getNumOperands(), operandPositions);
1939 if (diag.isSilenceableFailure()) {
1940 diag.attachNote(target->getLoc())
1941 << "while considering positions of this payload operation";
1942 return diag;
1943 }
1944 llvm::append_range(operands,
1945 llvm::map_range(operandPositions, [&](int64_t pos) {
1946 return target->getOperand(pos);
1947 }));
1948 }
1949 results.setValues(cast<OpResult>(getResult()), operands);
1951}
1952
1953LogicalResult transform::GetOperandOp::verify() {
1954 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1955 getIsInverted(), getIsAll());
1956}
1957
1958//===----------------------------------------------------------------------===//
1959// GetResultOp
1960//===----------------------------------------------------------------------===//
1961
/// Associates the result with the results of each payload op of `getTarget()`.
/// Results are selected by the position specification (`getIsAll()`,
/// `getIsInverted()`, `getRawPositionList()`) and concatenated in payload order.
1963transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1966 SmallVector<Value> opResults;
1967 for (Operation *target : state.getPayloadOps(getTarget())) {
  // Resolve the position specification against this op's result count.
1968 SmallVector<int64_t> resultPositions;
1970 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1971 target->getNumResults(), resultPositions);
1972 if (diag.isSilenceableFailure()) {
1973 diag.attachNote(target->getLoc())
1974 << "while considering positions of this payload operation";
1975 return diag;
1976 }
1977 llvm::append_range(opResults,
1978 llvm::map_range(resultPositions, [&](int64_t pos) {
1979 return target->getResult(pos);
1980 }));
1981 }
1982 results.setValues(cast<OpResult>(getResult()), opResults);
1984}
1985
1986LogicalResult transform::GetResultOp::verify() {
1987 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1988 getIsInverted(), getIsAll());
1989}
1990
1991//===----------------------------------------------------------------------===//
1992// GetTypeOp
1993//===----------------------------------------------------------------------===//
1994
/// GetTypeOp reads its value handle, produces its result handle, and only
/// reads (never modifies) the payload.
1995void transform::GetTypeOp::getEffects(
1997 onlyReadsHandle(getValueMutable(), effects);
1998 producesHandle(getOperation()->getOpResults(), effects);
1999 onlyReadsPayload(effects);
2000}
2001
/// Associates the result param with one TypeAttr per payload value of
/// `getValue()`. Each entry is the value's type or, when `getElemental()` is
/// set and the type is a ShapedType, its element type. Non-shaped types are
/// kept unchanged in both modes.
2003transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
2007 for (Value value : state.getPayloadValues(getValue())) {
2008 Type type = value.getType();
2009 if (getElemental()) {
2010 if (auto shaped = dyn_cast<ShapedType>(type)) {
2011 type = shaped.getElementType();
2012 }
2013 }
2014 params.push_back(TypeAttr::get(type));
2015 }
2016 results.setParams(cast<OpResult>(getResult()), params);
2018}
2019
2020//===----------------------------------------------------------------------===//
2021// IncludeOp
2022//===----------------------------------------------------------------------===//
2023
2024/// Applies the transform ops contained in `block`. Maps `results` to the same
2025/// values as the operands of the block terminator.
/// Failure handling for each op in the block:
///   - A definite failure is returned immediately.
///   - A silenceable failure is returned (after forwarding empty results) in
///     Propagate mode, and otherwise silenced so the next op runs.
2027applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2029 transform::TransformResults &results) {
2030 // Apply the sequenced ops one by one.
2031 for (Operation &transform : block.without_terminator()) {
2033 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2034 if (result.isDefiniteFailure())
2035 return result;
2036
2037 if (result.isSilenceableFailure()) {
2038 if (mode == transform::FailurePropagationMode::Propagate) {
2039 // Propagate empty results in case of early exit.
2040 forwardEmptyOperands(&block, state, results);
2041 return result;
2042 }
  // Non-propagating mode: drop the diagnostic and continue with the next op.
2043 (void)result.silence();
2044 }
2045 }
2046
2047 // Forward the operation mapping for values yielded from the sequence to the
2048 // values produced by the sequence op.
2049 transform::detail::forwardTerminatorOperands(&block, state, results);
2051}
2052
/// Applies the referenced named sequence to this op's operands:
///   1. The include operands' payloads are mapped to the callee's entry block
///      arguments inside a region scope for the callee body.
///   2. The callee body is run with this op's failure propagation mode.
///   3. On success, this op's results are mapped to the payloads of the
///      callee terminator's operands.
/// An external (body-less) callee is a definite failure.
2054transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2058 getOperation(), getTarget());
2059 assert(callee && "unverified reference to unknown symbol");
2060
2061 if (callee.isExternal())
2062 return emitDefiniteFailure() << "unresolved external named sequence";
2063
2064 // Map operands to block arguments.
2066 detail::prepareValueMappings(mappings, getOperands(), state);
  // `scope` stays alive until this function returns, so the callee's
  // mappings are still available when the terminator operands are read below.
2067 auto scope = state.make_region_scope(callee.getBody());
2068 for (auto &&[arg, map] :
2069 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2070 if (failed(state.mapBlockArgument(arg, map)))
2072 }
2073
2075 callee.getBody().front(), getFailurePropagationMode(), state, results);
2076
2077 if (!result.succeeded())
2078 return result;
2079
  // Reuse `mappings` to collect the payloads yielded by the callee.
2080 mappings.clear();
2081 detail::prepareValueMappings(
2082 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
  // Note: the structured binding `result` shadows the outer `result`; the
  // `return result;` after the loop refers to the outer diagnostic.
2083 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2084 results.setMappedValues(result, mapping);
2085 return result;
2086}
2087
// Forward declaration; the definition appears with the NamedSequenceOp code.
2089verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2090
2091void transform::IncludeOp::getEffects(
2093 // Always mark as modifying the payload.
2094 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2095 // only read, the payload may still be modified, so we currently stay on the
2096 // conservative side and always indicate modification. This may prevent some
2097 // code reordering.
2098 modifiesPayload(effects);
2099
2100 // Results are always produced.
2101 producesHandle(getOperation()->getOpResults(), effects);
2102
2103 // Adds default effects to operands and results. This will be added if
2104 // preconditions fail so the trait verifier doesn't complain about missing
2105 // effects and the real precondition failure is reported later on.
2106 auto defaultEffects = [&] {
2107 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2108 };
2109
2110 // Bail if the callee is unknown. This may run as part of the verification
2111 // process before we verified the validity of the callee or of this op.
2112 auto target =
2113 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2114 if (!target)
2115 return defaultEffects();
2117 getOperation(), getTarget());
2118 if (!callee)
2119 return defaultEffects();
2120
2121 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2122 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2123 consumesHandle(getOperation()->getOpOperand(i), effects);
2124 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2125 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2126 }
2127}
2128
/// Verifies the reference to the callee named sequence. The checks are:
///   - the `target` attribute exists and resolves to a NamedSequenceOp;
///   - the operand count matches the callee inputs, with identical types;
///   - the result count matches the callee results, with each result type
///     implementing the same transform interface as the callee's result;
///   - a final check on the callee itself passes (warnings off, internal
///     verification on).
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // Access through indirection and do additional checking because this may be
 // running before the main op verifier.
 // NOTE(review): uses the literal "target" here, whereas getEffects uses
 // getTargetAttrName(); consider unifying.
 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
 if (!targetAttr)
 return emitOpError() << "expects a 'target' symbol reference attribute";

 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
 *this, targetAttr);
 if (!target)
 return emitOpError() << "does not reference a named transform sequence";

 FunctionType fnType = target.getFunctionType();
 if (fnType.getNumInputs() != getNumOperands())
 return emitError("incorrect number of operands for callee");

 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
 if (getOperand(i).getType() != fnType.getInput(i)) {
 return emitOpError("operand type mismatch: expected operand type ")
 << fnType.getInput(i) << ", but provided "
 << getOperand(i).getType() << " for operand number " << i;
 }
 }

 if (fnType.getNumResults() != getNumResults())
 return emitError("incorrect number of results for callee");

 // Results only need to agree on the transform interface (op handle, value
 // handle or param), not on the exact type.
 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
 Type resultType = getResult(i).getType();
 Type funcType = fnType.getResult(i);
 if (!implementSameTransformInterface(resultType, funcType)) {
 return emitOpError() << "type of result #" << i
 << " must implement the same transform dialect "
 "interface as the corresponding callee result";
 }
 }

2168 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2169 /*alsoVerifyInternal=*/true)
2170 .checkAndReport();
2171}
2172
2173//===----------------------------------------------------------------------===//
2174// MatchOperationEmptyOp
2175//===----------------------------------------------------------------------===//
2176
/// Matches only when there is no current operation (`maybeCurrent` holds no
/// value). Otherwise it produces a silenceable "operation is not empty" error.
2177DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2178 ::std::optional<::mlir::Operation *> maybeCurrent,
2180 if (!maybeCurrent.has_value()) {
2181 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2183 }
2184 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2185 return emitSilenceableError() << "operation is not empty";
2186}
2187
2188//===----------------------------------------------------------------------===//
2189// MatchOperationNameOp
2190//===----------------------------------------------------------------------===//
2191
/// Matches when the current op's name equals one of the names in
/// `getOpNames()`. Otherwise it produces a silenceable "wrong operation name"
/// error.
2192DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2193 Operation *current, transform::TransformResults &results,
2195 StringRef currentOpName = current->getName().getStringRef();
2196 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2197 if (acceptedAttr.getValue() == currentOpName)
2199 }
2200 return emitSilenceableError() << "wrong operation name";
2201}
2202
2203//===----------------------------------------------------------------------===//
2204// MatchParamCmpIOp
2205//===----------------------------------------------------------------------===//
2206
/// Compares the params of `getParam()` element-wise against those of
/// `getReference()`, using the *signed* integer predicate `getPredicate()`.
/// Failure cases:
///   - different list lengths, or a pair failing the predicate: silenceable
///     failure;
///   - a non-integer attribute, or integer attributes of different types:
///     definite failure.
2208transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
  // Renders an APInt as a signed decimal string for diagnostics.
2211 auto signedAPIntAsString = [&](const APInt &value) {
2212 std::string str;
2213 llvm::raw_string_ostream os(str);
2214 value.print(os, /*isSigned=*/true);
2215 return str;
2216 };
2217
2218 ArrayRef<Attribute> params = state.getParams(getParam());
2219 ArrayRef<Attribute> references = state.getParams(getReference());
2220
2221 if (params.size() != references.size()) {
2222 return emitSilenceableError()
2223 << "parameters have different payload lengths (" << params.size()
2224 << " vs " << references.size() << ")";
2225 }
2226
2227 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2228 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2229 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2230 if (!intAttr || !refAttr) {
2231 return emitDefiniteFailure()
2232 << "non-integer parameter value not expected";
2233 }
2234 if (intAttr.getType() != refAttr.getType()) {
2235 return emitDefiniteFailure()
2236 << "mismatching integer attribute types in parameter #" << i;
2237 }
2238 APInt value = intAttr.getValue();
2239 APInt refValue = refAttr.getValue();
2240
2241 // TODO: this copy will not be necessary in C++20.
2242 int64_t position = i;
2243 auto reportError = [&](StringRef direction) {
2245 emitSilenceableError() << "expected parameter to be " << direction
2246 << " " << signedAPIntAsString(refValue)
2247 << ", got " << signedAPIntAsString(value);
2248 diag.attachNote(getParam().getLoc())
2249 << "value # " << position
2250 << " associated with the parameter defined here";
2251 return diag;
2252 };
2253
  // Each case breaks out on success and reports a silenceable error otherwise.
2254 switch (getPredicate()) {
2255 case MatchCmpIPredicate::eq:
2256 if (value.eq(refValue))
2257 break;
2258 return reportError("equal to");
2259 case MatchCmpIPredicate::ne:
2260 if (value.ne(refValue))
2261 break;
2262 return reportError("not equal to");
2263 case MatchCmpIPredicate::lt:
2264 if (value.slt(refValue))
2265 break;
2266 return reportError("less than");
2267 case MatchCmpIPredicate::le:
2268 if (value.sle(refValue))
2269 break;
2270 return reportError("less than or equal to");
2271 case MatchCmpIPredicate::gt:
2272 if (value.sgt(refValue))
2273 break;
2274 return reportError("greater than");
2275 case MatchCmpIPredicate::ge:
2276 if (value.sge(refValue))
2277 break;
2278 return reportError("greater than or equal to");
2279 }
2280 }
2282}
2283
/// MatchParamCmpIOp only reads its two param handles; no payload effect is
/// declared here.
2284void transform::MatchParamCmpIOp::getEffects(
2286 onlyReadsHandle(getParamMutable(), effects);
2287 onlyReadsHandle(getReferenceMutable(), effects);
2288}
2289
2290//===----------------------------------------------------------------------===//
2291// ParamConstantOp
2292//===----------------------------------------------------------------------===//
2293
/// Associates the result param with the single constant attribute
/// `getValue()`.
2295transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2298 results.setParams(cast<OpResult>(getParam()), {getValue()});
2300}
2301
2302//===----------------------------------------------------------------------===//
2303// MergeHandlesOp
2304//===----------------------------------------------------------------------===//
2305
/// Concatenates the payloads of all `getHandles()` into the result, in operand
/// order. The first handle's type decides whether ops, params or values are
/// merged. With `getDeduplicate()`, only the first occurrence of each entry is
/// kept (SetVector preserves insertion order).
2307transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
  // NOTE(review): relies on at least one handle (handles.front()); presumably
  // guaranteed by the op definition. Confirm.
2310 ValueRange handles = getHandles();
  // Operation handles.
2311 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2312 SmallVector<Operation *> operations;
2313 for (Value operand : handles)
2314 llvm::append_range(operations, state.getPayloadOps(operand));
2315 if (!getDeduplicate()) {
2316 results.set(llvm::cast<OpResult>(getResult()), operations);
2318 }
2319
2320 SetVector<Operation *> uniqued(llvm::from_range, operations);
2321 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2323 }
2324
  // Parameter handles.
2325 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2327 for (Value attribute : handles)
2328 llvm::append_range(attrs, state.getParams(attribute));
2329 if (!getDeduplicate()) {
2330 results.setParams(cast<OpResult>(getResult()), attrs);
2332 }
2333
2334 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2335 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2337 }
2338
  // Remaining case: value handles.
2339 assert(
2340 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2341 "expected value handle type");
2342 SmallVector<Value> payloadValues;
2343 for (Value value : handles)
2344 llvm::append_range(payloadValues, state.getPayloadValues(value));
2345 if (!getDeduplicate()) {
2346 results.setValues(cast<OpResult>(getResult()), payloadValues);
2348 }
2349
2350 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2351 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2353}
2354
2355bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2356 // Handles may be the same if deduplicating is enabled.
2357 return getDeduplicate();
2358}
2359
/// MergeHandlesOp reads all input handles and produces its result handle.
2360void transform::MergeHandlesOp::getEffects(
2362 onlyReadsHandle(getHandlesMutable(), effects);
2363 producesHandle(getOperation()->getOpResults(), effects);
2364
2365 // There are no effects on the Payload IR as this is only a handle
2366 // manipulation.
2367}
2368
2369OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2370 if (getDeduplicate() || getHandles().size() != 1)
2371 return {};
2372
2373 // If deduplication is not required and there is only one operand, it can be
2374 // used directly instead of merging.
2375 return getHandles().front();
2376}
2377
2378//===----------------------------------------------------------------------===//
2379// NamedSequenceOp
2380//===----------------------------------------------------------------------===//
2381
/// Applies the body of the named sequence directly. Entry block arguments are
/// mapped as for a possible top-level transform op, and the body runs in
/// Propagate failure mode. An external (body-less) sequence is a definite
/// failure.
2383transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2386 if (isExternal())
2387 return emitDefiniteFailure() << "unresolved external named sequence";
2388
2389 // Map the entry block argument to the list of operations.
2390 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2391 // without attaching the interface / trait since that is tailored to a
2392 // dangling top-level op that does not get "called".
2393 auto scope = state.make_region_scope(getBody());
2394 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2395 state, this->getOperation(), getBody())))
2397
2398 return applySequenceBlock(getBody().front(),
2399 FailurePropagationMode::Propagate, state, results);
2400}
2401
/// Memory-effects hook for NamedSequenceOp. Handle effects of calling a
/// sequence are derived at `transform.include` sites from the callee's
/// consumed/readonly argument attributes.
2402void transform::NamedSequenceOp::getEffects(
2404
/// Parses the op using the generic function-op syntax. Variadic arguments are
/// disallowed, the function type is built from the parsed argument and result
/// types, and argument/result attributes are stored under the op's
/// `arg_attrs` / `res_attrs` names.
2405ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2408 parser, result, /*allowVariadic=*/false,
2409 getFunctionTypeAttrName(result.name),
2410 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2412 std::string &) { return builder.getFunctionType(inputs, results); },
2413 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2414}
2415
/// Prints the op in the generic function-op syntax (non-variadic), the
/// counterpart of `parse`.
2416void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2418 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2419 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2420 getResAttrsAttrName());
2421}
2422
2423/// Verifies that a symbol function-like transform dialect operation has the
2424/// signature and the terminator that have conforming types, i.e., types
2425/// implementing the same transform dialect type interface. If `allowExternal`
2426/// is set, allow external symbols (declarations) and don't check the terminator
2427/// as it may not exist.
2429verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
  // Function-like transform ops must not be nested in other transform ops.
2430 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2433 << "cannot be defined inside another transform op";
2434 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2435 return diag;
2436 }
2437
  // Declarations have no body, hence no terminator to check.
2438 if (op.isExternal() || op.getFunctionBody().empty()) {
2439 if (allowExternal)
2441
2442 return emitSilenceableFailure(op) << "cannot be external";
2443 }
2444
2445 if (op.getFunctionBody().front().empty())
2446 return emitSilenceableFailure(op) << "expected a non-empty body block";
2447
  // The last op of the entry block must be `transform.yield`.
2448 Operation *terminator = &op.getFunctionBody().front().back();
2449 if (!isa<transform::YieldOp>(terminator)) {
2451 << "expected '"
2452 << transform::YieldOp::getOperationName()
2453 << "' as terminator";
2454 diag.attachNote(terminator->getLoc()) << "terminator";
2455 return diag;
2456 }
2457
  // Yielded values must match the function results exactly, in count and type.
2458 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2459 return emitSilenceableFailure(terminator)
2460 << "expected terminator to have as many operands as the parent op "
2461 "has results";
2462 }
2463 for (auto [i, operandType, resultType] : llvm::zip_equal(
2464 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2465 terminator->getOperands().getType(), op.getResultTypes())) {
2466 if (operandType == resultType)
2467 continue;
2468 return emitSilenceableFailure(terminator)
2469 << "the type of the terminator operand #" << i
2470 << " must match the type of the corresponding parent op result ("
2471 << operandType << " vs " << resultType << ")";
2472 }
2473
2475}
2476
2477/// Verification of a NamedSequenceOp. This does not report the error
2478/// immediately, so it can be used to check for op's well-formedness before the
2479/// verifier runs, e.g., during trait verification.
2481verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
  // The enclosing symbol table must opt in to holding named sequences.
2482 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2483 if (!parent->getAttr(
2484 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2487 << "expects the parent symbol table to have the '"
2488 << transform::TransformDialect::kWithNamedSequenceAttrName
2489 << "' attribute";
2490 diag.attachNote(parent->getLoc()) << "symbol table operation";
2491 return diag;
2492 }
2493 }
2494
2495 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2498 << "cannot be defined inside another transform op";
2499 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2500 return diag;
2501 }
2502
  // Declarations: only the consume/readonly argument annotations are checked.
2503 if (op.isExternal() || op.getBody().empty())
2504 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2505 emitWarnings);
2506
2507 if (op.getBody().front().empty())
2508 return emitSilenceableFailure(op) << "expected a non-empty body block";
2509
2510 // Check that all operations in the body implement TransformOpInterface
2511 for (Operation &child : op.getBody().front().without_terminator()) {
2512 if (!isa<transform::TransformOpInterface>(child)) {
2515 << "expected children ops to implement TransformOpInterface";
2516 diag.attachNote(child.getLoc()) << "op without interface";
2517 return diag;
2518 }
2519 }
2520
2521 Operation *terminator = &op.getBody().front().back();
2522 if (!isa<transform::YieldOp>(terminator)) {
2524 << "expected '"
2525 << transform::YieldOp::getOperationName()
2526 << "' as terminator";
2527 diag.attachNote(terminator->getLoc()) << "terminator";
2528 return diag;
2529 }
2530
2531 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2532 return emitSilenceableFailure(terminator)
2533 << "expected terminator to have as many operands as the parent op "
2534 "has results";
2535 }
2536 for (auto [i, operandType, resultType] :
2537 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2538 terminator->getOperands().getType(),
2539 op.getFunctionType().getResults())) {
2540 if (operandType == resultType)
2541 continue;
2542 return emitSilenceableFailure(terminator)
2543 << "the type of the terminator operand #" << i
2544 << " must match the type of the corresponding parent op result ("
2545 << operandType << " vs " << resultType << ")";
2546 }
2547
2548 auto funcOp = cast<FunctionOpInterface>(*op);
2550 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2551 if (!diag.succeeded())
2552 return diag;
2553
  // NOTE(review): verifyYieldingSingleBlockOp repeats the ancestor,
  // non-empty-body and terminator checks already performed above.
2554 return verifyYieldingSingleBlockOp(funcOp,
2555 /*allowExternal=*/true);
2556}
2557
2558LogicalResult transform::NamedSequenceOp::verify() {
2559 // Actual verification happens in a separate function for reusability.
2560 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2561}
2562
2563template <typename FnTy>
2564static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2565 Type bbArgType, TypeRange extraBindingTypes,
2566 FnTy bodyBuilder) {
2567 SmallVector<Type> types;
2568 types.reserve(1 + extraBindingTypes.size());
2569 types.push_back(bbArgType);
2570 llvm::append_range(types, extraBindingTypes);
2571
2572 OpBuilder::InsertionGuard guard(builder);
2573 Region *region = state.regions.back().get();
2574 Block *bodyBlock =
2575 builder.createBlock(region, region->begin(), types,
2576 SmallVector<Location>(types.size(), state.location));
2577
2578 // Populate body.
2579 builder.setInsertionPointToStart(bodyBlock);
2580 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2581 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2582 } else {
2583 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2584 bodyBlock->getArguments().drop_front());
2585 }
2586}
2587
2588void transform::NamedSequenceOp::build(OpBuilder &builder,
2589 OperationState &state, StringRef symName,
2590 Type rootType, TypeRange resultTypes,
2591 SequenceBodyBuilderFn bodyBuilder,
2593 ArrayRef<DictionaryAttr> argAttrs) {
2595 builder.getStringAttr(symName));
2596 state.addAttribute(getFunctionTypeAttrName(state.name),
2597 TypeAttr::get(FunctionType::get(builder.getContext(),
2598 rootType, resultTypes)));
2599 state.attributes.append(attrs.begin(), attrs.end());
2600 state.addRegion();
2601
2602 buildSequenceBody(builder, state, rootType,
2603 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2604}
2605
2606//===----------------------------------------------------------------------===//
2607// NumAssociationsOp
2608//===----------------------------------------------------------------------===//
2609
/// Associates `getNum()` with a single i64 param holding the number of payload
/// entries (ops, values, or params, depending on the handle's type interface)
/// associated with `getHandle()`.
2611transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2614 size_t numAssociations =
2616 .Case([&](TransformHandleTypeInterface opHandle) {
2617 return llvm::range_size(state.getPayloadOps(getHandle()));
2618 })
2619 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2620 return llvm::range_size(state.getPayloadValues(getHandle()));
2621 })
2622 .Case([&](TransformParamTypeInterface param) {
2623 return llvm::range_size(state.getParams(getHandle()));
2624 })
2625 .DefaultUnreachable("unknown kind of transform dialect type");
2626 results.setParams(cast<OpResult>(getNum()),
2627 rewriter.getI64IntegerAttr(numAssociations));
2629}
2630
2631LogicalResult transform::NumAssociationsOp::verify() {
2632 // Verify that the result type accepts an i64 attribute as payload.
2633 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2634 return resultType
2635 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2636 .checkAndReport();
2637}
2638
2639//===----------------------------------------------------------------------===//
2640// SelectOp
2641//===----------------------------------------------------------------------===//
2642
/// Keeps the payload ops of `getTarget()` whose name equals `getOpName()`,
/// preserving their order, and associates them with the result.
2644transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2648 auto payloadOps = state.getPayloadOps(getTarget());
2649 for (Operation *op : payloadOps) {
2650 if (op->getName().getStringRef() == getOpName())
2651 result.push_back(op);
2652 }
2653 results.set(cast<OpResult>(getResult()), result);
2655}
2656
2657//===----------------------------------------------------------------------===//
2658// SplitHandleOp
2659//===----------------------------------------------------------------------===//
2660
2661void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2662 Value target, int64_t numResultHandles) {
2663 result.addOperands(target);
2664 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2665}
2666
2668transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2671 int64_t numPayloads =
2673 .Case([&](TransformHandleTypeInterface x) {
2674 return llvm::range_size(state.getPayloadOps(getHandle()));
2675 })
2676 .Case([&](TransformValueHandleTypeInterface x) {
2677 return llvm::range_size(state.getPayloadValues(getHandle()));
2678 })
2679 .Case([&](TransformParamTypeInterface x) {
2680 return llvm::range_size(state.getParams(getHandle()));
2681 })
2682 .DefaultUnreachable("unknown transform dialect type interface");
2683
2684 auto produceNumOpsError = [&]() {
2685 return emitSilenceableError()
2686 << getHandle() << " expected to contain " << this->getNumResults()
2687 << " payloads but it contains " << numPayloads << " payloads";
2688 };
2689
2690 // Fail if there are more payload ops than results and no overflow result was
2691 // specified.
2692 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2693 return produceNumOpsError();
2694
2695 // Fail if there are more results than payload ops. Unless:
2696 // - "fail_on_payload_too_small" is set to "false", or
2697 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2698 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2699 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2700 return produceNumOpsError();
2701
2702 // Distribute payloads.
2703 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2704 if (getOverflowResult())
2705 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2706
2707 auto container = [&]() {
2708 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2709 return llvm::map_to_vector(
2710 state.getPayloadOps(getHandle()),
2711 [](Operation *op) -> MappedValue { return op; });
2712 }
2713 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2714 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2715 [](Value v) -> MappedValue { return v; });
2716 }
2717 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2718 "unsupported kind of transform dialect type");
2719 return llvm::map_to_vector(state.getParams(getHandle()),
2720 [](Attribute a) -> MappedValue { return a; });
2721 }();
2722
2723 for (auto &&en : llvm::enumerate(container)) {
2724 int64_t resultNum = en.index();
2725 if (resultNum >= getNumResults())
2726 resultNum = *getOverflowResult();
2727 resultHandles[resultNum].push_back(en.value());
2728 }
2729
2730 // Set transform op results.
2731 for (auto &&it : llvm::enumerate(resultHandles))
2732 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2733 it.value());
2734
2736}
2737
2738void transform::SplitHandleOp::getEffects(
2740 onlyReadsHandle(getHandleMutable(), effects);
2741 producesHandle(getOperation()->getOpResults(), effects);
2742 // There are no effects on the Payload IR as this is only a handle
2743 // manipulation.
2744}
2745
2746LogicalResult transform::SplitHandleOp::verify() {
2747 if (getOverflowResult().has_value() &&
2748 !(*getOverflowResult() < getNumResults()))
2749 return emitOpError("overflow_result is not a valid result index");
2750
2751 for (Type resultType : getResultTypes()) {
2752 if (implementSameTransformInterface(getHandle().getType(), resultType))
2753 continue;
2754
2755 return emitOpError("expects result types to implement the same transform "
2756 "interface as the operand type");
2757 }
2758
2759 return success();
2760}
2761
2762//===----------------------------------------------------------------------===//
2763// ReplicateOp
2764//===----------------------------------------------------------------------===//
2765
2767transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2770 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2771 for (const auto &en : llvm::enumerate(getHandles())) {
2772 Value handle = en.value();
2773 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2774 SmallVector<Operation *> current =
2775 llvm::to_vector(state.getPayloadOps(handle));
2777 payload.reserve(numRepetitions * current.size());
2778 for (unsigned i = 0; i < numRepetitions; ++i)
2779 llvm::append_range(payload, current);
2780 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2781 } else {
2782 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2783 "expected param type");
2784 ArrayRef<Attribute> current = state.getParams(handle);
2786 params.reserve(numRepetitions * current.size());
2787 for (unsigned i = 0; i < numRepetitions; ++i)
2788 llvm::append_range(params, current);
2789 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2790 params);
2791 }
2792 }
2794}
2795
2796void transform::ReplicateOp::getEffects(
2798 onlyReadsHandle(getPatternMutable(), effects);
2799 onlyReadsHandle(getHandlesMutable(), effects);
2800 producesHandle(getOperation()->getOpResults(), effects);
2801}
2802
2803//===----------------------------------------------------------------------===//
2804// SequenceOp
2805//===----------------------------------------------------------------------===//
2806
2808transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2811 // Map the entry block argument to the list of operations.
2812 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2813 if (failed(mapBlockArguments(state)))
2815
2816 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2817 results);
2818}
2819
2820static ParseResult parseSequenceOpOperands(
2821 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2822 Type &rootType,
2824 SmallVectorImpl<Type> &extraBindingTypes) {
2826 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2827 if (!hasRoot.has_value()) {
2828 root = std::nullopt;
2829 return success();
2830 }
2831 if (failed(hasRoot.value()))
2832 return failure();
2833 root = rootOperand;
2834
2835 if (succeeded(parser.parseOptionalComma())) {
2836 if (failed(parser.parseOperandList(extraBindings)))
2837 return failure();
2838 }
2839 if (failed(parser.parseColon()))
2840 return failure();
2841
2842 // The paren is truly optional.
2843 (void)parser.parseOptionalLParen();
2844
2845 if (failed(parser.parseType(rootType))) {
2846 return failure();
2847 }
2848
2849 if (!extraBindings.empty()) {
2850 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2851 return failure();
2852 }
2853
2854 if (extraBindingTypes.size() != extraBindings.size()) {
2855 return parser.emitError(parser.getNameLoc(),
2856 "expected types to be provided for all operands");
2857 }
2858
2859 // The paren is truly optional.
2860 (void)parser.parseOptionalRParen();
2861 return success();
2862}
2863
2865 Value root, Type rootType,
2866 ValueRange extraBindings,
2867 TypeRange extraBindingTypes) {
2868 if (!root)
2869 return;
2870
2871 printer << root;
2872 bool hasExtras = !extraBindings.empty();
2873 if (hasExtras) {
2874 printer << ", ";
2875 printer.printOperands(extraBindings);
2876 }
2877
2878 printer << " : ";
2879 if (hasExtras)
2880 printer << "(";
2881
2882 printer << rootType;
2883 if (hasExtras)
2884 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2885}
2886
2887/// Returns `true` if the given op operand may be consuming the handle value in
2888/// the Transform IR. That is, if it may have a Free effect on it.
2890 // Conservatively assume the effect being present in absence of the interface.
2891 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2892 if (!iface)
2893 return true;
2894
2895 return isHandleConsumed(use.get(), iface);
2896}
2897
2898static LogicalResult
2900 function_ref<InFlightDiagnostic()> reportError) {
2901 OpOperand *potentialConsumer = nullptr;
2902 for (OpOperand &use : value.getUses()) {
2904 continue;
2905
2906 if (!potentialConsumer) {
2907 potentialConsumer = &use;
2908 continue;
2909 }
2910
2911 InFlightDiagnostic diag = reportError()
2912 << " has more than one potential consumer";
2913 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2914 << "used here as operand #" << potentialConsumer->getOperandNumber();
2915 diag.attachNote(use.getOwner()->getLoc())
2916 << "used here as operand #" << use.getOperandNumber();
2917 return diag;
2918 }
2919
2920 return success();
2921}
2922
2923LogicalResult transform::SequenceOp::verify() {
2924 assert(getBodyBlock()->getNumArguments() >= 1 &&
2925 "the number of arguments must have been verified to be more than 1 by "
2926 "PossibleTopLevelTransformOpTrait");
2927
2928 if (!getRoot() && !getExtraBindings().empty()) {
2929 return emitOpError()
2930 << "does not expect extra operands when used as top-level";
2931 }
2932
2933 // Check if a block argument has more than one consuming use.
2934 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2935 if (failed(checkDoubleConsume(arg, [this, arg]() {
2936 return (emitOpError() << "block argument #" << arg.getArgNumber());
2937 }))) {
2938 return failure();
2939 }
2940 }
2941
2942 // Check properties of the nested operations they cannot check themselves.
2943 for (Operation &child : *getBodyBlock()) {
2944 if (!isa<TransformOpInterface>(child) &&
2945 &child != &getBodyBlock()->back()) {
2947 emitOpError()
2948 << "expected children ops to implement TransformOpInterface";
2949 diag.attachNote(child.getLoc()) << "op without interface";
2950 return diag;
2951 }
2952
2953 for (OpResult result : child.getResults()) {
2954 auto report = [&]() {
2955 return (child.emitError() << "result #" << result.getResultNumber());
2956 };
2957 if (failed(checkDoubleConsume(result, report)))
2958 return failure();
2959 }
2960 }
2961
2962 if (!getBodyBlock()->mightHaveTerminator())
2963 return emitOpError() << "expects to have a terminator in the body";
2964
2965 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2966 getOperation()->getResultTypes()) {
2968 << "expects the types of the terminator operands "
2969 "to match the types of the result";
2970 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2971 return diag;
2972 }
2973 return success();
2974}
2975
2976void transform::SequenceOp::getEffects(
2979}
2980
2982transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2983 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2984 if (getOperation()->getNumOperands() > 0)
2985 return getOperation()->getOperands();
2986 return OperandRange(getOperation()->operand_end(),
2987 getOperation()->operand_end());
2988}
2989
2990void transform::SequenceOp::getSuccessorRegions(
2992 if (point.isParent()) {
2993 Region *bodyRegion = &getBody();
2994 regions.emplace_back(bodyRegion);
2995 return;
2996 }
2997
2998 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2999 &getBody() &&
3000 "unexpected region index");
3001 regions.push_back(RegionSuccessor::parent());
3002}
3003
3005transform::SequenceOp::getSuccessorInputs(RegionSuccessor successor) {
3006 if (getNumOperands() == 0)
3007 return ValueRange();
3008 if (successor.isParent())
3009 return getResults();
3010 return getBody().getArguments();
3011}
3012
3013void transform::SequenceOp::getRegionInvocationBounds(
3015 (void)operands;
3016 bounds.emplace_back(1, 1);
3017}
3018
3019void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3020 TypeRange resultTypes,
3021 FailurePropagationMode failurePropagationMode,
3022 Value root,
3023 SequenceBodyBuilderFn bodyBuilder) {
3024 build(builder, state, resultTypes, failurePropagationMode, root,
3025 /*extra_bindings=*/ValueRange());
3026 Type bbArgType = root.getType();
3027 buildSequenceBody(builder, state, bbArgType,
3028 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3029}
3030
3031void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3032 TypeRange resultTypes,
3033 FailurePropagationMode failurePropagationMode,
3034 Value root, ValueRange extraBindings,
3035 SequenceBodyBuilderArgsFn bodyBuilder) {
3036 build(builder, state, resultTypes, failurePropagationMode, root,
3037 extraBindings);
3038 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3039 bodyBuilder);
3040}
3041
3042void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3043 TypeRange resultTypes,
3044 FailurePropagationMode failurePropagationMode,
3045 Type bbArgType,
3046 SequenceBodyBuilderFn bodyBuilder) {
3047 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3048 /*extra_bindings=*/ValueRange());
3049 buildSequenceBody(builder, state, bbArgType,
3050 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3051}
3052
3053void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3054 TypeRange resultTypes,
3055 FailurePropagationMode failurePropagationMode,
3056 Type bbArgType, TypeRange extraBindingTypes,
3057 SequenceBodyBuilderArgsFn bodyBuilder) {
3058 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3059 /*extra_bindings=*/ValueRange());
3060 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3061}
3062
3063//===----------------------------------------------------------------------===//
3064// PrintOp
3065//===----------------------------------------------------------------------===//
3066
3067void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3068 StringRef name) {
3069 if (!name.empty())
3070 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3071}
3072
3073void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3074 Value target, StringRef name) {
3075 result.addOperands({target});
3076 build(builder, result, name);
3077}
3078
3080transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3083 llvm::outs() << "[[[ IR printer: ";
3084 if (getName().has_value())
3085 llvm::outs() << *getName() << " ";
3086
3087 OpPrintingFlags printFlags;
3088 if (getAssumeVerified().value_or(false))
3089 printFlags.assumeVerified();
3090 if (getUseLocalScope().value_or(false))
3091 printFlags.useLocalScope();
3092 if (getSkipRegions().value_or(false))
3093 printFlags.skipRegions();
3094
3095 if (!getTarget()) {
3096 llvm::outs() << "top-level ]]]\n";
3097 state.getTopLevel()->print(llvm::outs(), printFlags);
3098 llvm::outs() << "\n";
3099 llvm::outs().flush();
3101 }
3102
3103 llvm::outs() << "]]]\n";
3104 for (Operation *target : state.getPayloadOps(getTarget())) {
3105 target->print(llvm::outs(), printFlags);
3106 llvm::outs() << "\n";
3107 }
3108
3109 llvm::outs().flush();
3111}
3112
3113void transform::PrintOp::getEffects(
3115 // We don't really care about mutability here, but `getTarget` now
3116 // unconditionally casts to a specific type before verification could run
3117 // here.
3118 if (!getTargetMutable().empty())
3119 onlyReadsHandle(getTargetMutable()[0], effects);
3120 onlyReadsPayload(effects);
3121
3122 // There is no resource for stderr file descriptor, so just declare print
3123 // writes into the default resource.
3124 effects.emplace_back(MemoryEffects::Write::get());
3125}
3126
3127//===----------------------------------------------------------------------===//
3128// VerifyOp
3129//===----------------------------------------------------------------------===//
3130
3132transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3138 << "failed to verify payload op";
3139 diag.attachNote(target->getLoc()) << "payload op";
3140 return diag;
3141 }
3143}
3144
3145void transform::VerifyOp::getEffects(
3147 transform::onlyReadsHandle(getTargetMutable(), effects);
3148}
3149
3150//===----------------------------------------------------------------------===//
3151// YieldOp
3152//===----------------------------------------------------------------------===//
3153
3154void transform::YieldOp::getEffects(
3156 onlyReadsHandle(getOperandsMutable(), effects);
3157}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getOpResults()
Definition Operation.h:420
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.