MLIR 23.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
35#include "llvm/ADT/DenseSet.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/ScopeExit.h"
38#include "llvm/ADT/SmallPtrSet.h"
39#include "llvm/ADT/SmallVectorExtras.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/DebugLog.h"
43#include "llvm/Support/ErrorHandling.h"
44#include "llvm/Support/InterleavedRange.h"
45#include <optional>
46
47#define DEBUG_TYPE "transform-dialect"
48#define DEBUG_TYPE_MATCHER "transform-matcher"
49
50using namespace mlir;
51
// Forward declarations of the file-local (static) helpers that parse and print
// the custom assembly pieces used by ops in this file: registered-pass
// options, sequence-op operands, and foreach-match symbols.
// NOTE(review): the embedded listing numbers skip (53->56, 61->63, 68->70,
// 70->72), so some declaration lines appear to be missing from this listing.
52static ParseResult parseApplyRegisteredPassOptions(
53 OpAsmParser &parser, DictionaryAttr &options,
56 Operation *op,
57 DictionaryAttr options,
58 ValueRange dynamicOptions);
59static ParseResult parseSequenceOpOperands(
60 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
61 Type &rootType,
63 SmallVectorImpl<Type> &extraBindingTypes);
64static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
65 Value root, Type rootType,
66 ValueRange extraBindings,
67 TypeRange extraBindingTypes);
68static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
70static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
72 ArrayAttr &actions);
73
74/// Helper function to check if the given transform op is contained in (or
75/// equal to) the given payload target op. In that case, an error is returned.
76/// Transforming transform IR that is currently executing is generally unsafe.
// The check walks from the transform op upwards through its parent ops and
// compares each ancestor (including the op itself) against `payload`.
// NOTE(review): listing lines 77, 83 and 91 are absent here; presumably they
// hold the return type, the declaration of `diag`, and the success return --
// confirm against the original file.
78ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
79 Operation *payload) {
80 Operation *transformAncestor = transform.getOperation();
 // Loop ends when getParentOp() yields null, i.e. past the outermost op.
81 while (transformAncestor) {
82 if (transformAncestor == payload) {
 // Definite failure with a note pointing at the offending payload op.
84 transform.emitDefiniteFailure()
85 << "cannot apply transform to itself (or one of its ancestors)";
86 diag.attachNote(payload->getLoc()) << "target payload op";
87 return diag;
88 }
89 transformAncestor = transformAncestor->getParentOp();
90 }
92}
93
94#define GET_OP_CLASSES
95#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
96
97//===----------------------------------------------------------------------===//
98// AlternativesOp
99//===----------------------------------------------------------------------===//
100
101OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
102 RegionSuccessor successor) {
103 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
104 return getOperation()->getOperands();
105 return OperandRange(getOperation()->operand_end(),
106 getOperation()->operand_end());
107}
108
// Appends the possible successors of `point` to `regions`.
// From the parent, every alternative region is a successor (nothing is
// dropped). From inside a region, only the alternatives after the current
// region number are added, followed by the parent as a successor.
// NOTE(review): listing lines 110 and 114 are absent (the parameter list and
// the start of the expression that yields the current region) -- confirm.
109void transform::AlternativesOp::getSuccessorRegions(
111 for (Region &alternative : llvm::drop_begin(
112 getAlternatives(), point.isParent()
113 ? 0
115 ->getParentRegion()
116 ->getRegionNumber() +
117 1)) {
118 regions.emplace_back(&alternative);
119 }
 // Control may return to the parent only from within a region.
120 if (!point.isParent())
121 regions.push_back(RegionSuccessor::parent());
122}
123
// Returns the values that receive data when control reaches `successor`:
// the op's results when the successor is the parent, otherwise the block
// arguments of the successor region.
// NOTE(review): the return-type line (listing line 124) is absent here.
125transform::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) {
126 if (successor.isParent())
127 return getOperation()->getResults();
128 return successor.getSuccessor()->getArguments();
129}
130
// Fills `bounds` with one InvocationBounds entry per region: (1, 1) for the
// first region and (0, 1) for every remaining one. `operands` is unused.
// NOTE(review): the parameter-list line (listing line 132) is absent here.
131void transform::AlternativesOp::getRegionInvocationBounds(
133 (void)operands;
134 // The region corresponding to the first alternative is always executed, the
135 // remaining may or may not be executed.
136 bounds.reserve(getNumRegions());
137 bounds.emplace_back(1, 1);
 // resize() pads up to getNumRegions() entries with the (0, 1) bound.
138 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
139}
140
// Body of a helper whose signature (listing lines 141-142) is missing from
// this listing. It associates every result of the op that owns `block` with
// an empty payload list in `results`.
143 for (const auto &res : block->getParentOp()->getOpResults())
144 results.set(res, {});
145}
146
// Applies the alternatives in order, each on a fresh clone of the scope
// payload ops. The first alternative whose transforms all succeed wins: its
// clones replace the originals and its terminator operands are forwarded to
// the op results. If every alternative fails silenceably, a silenceable error
// is emitted.
// NOTE(review): several listing lines are absent (147, 149-150, 159, 165,
// 185, 189, 198, 217) -- return type, remaining parameters, diagnostic
// declarations and some return statements. Confirm against the original.
148transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
151 SmallVector<Operation *> originals;
 // Scope is either the payload of the explicit scope handle or the top-level
 // op known to the transform state.
152 if (Value scopeHandle = getScope())
153 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
154 else
155 originals.push_back(state.getTopLevel());
156
 // Validate every scope op before doing any work.
157 for (Operation *original : originals) {
158 if (original->isAncestor(getOperation())) {
160 << "scope must not contain the transforms being applied";
161 diag.attachNote(original->getLoc()) << "scope";
162 return diag;
163 }
164 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
166 << "only isolated-from-above ops can be alternative scopes";
167 diag.attachNote(original->getLoc()) << "scope";
168 return diag;
169 }
170 }
171
172 for (Region &reg : getAlternatives()) {
173 // Clone the scope operations and make the transforms in this alternative
174 // region apply to them by virtue of mapping the block argument (the only
175 // visible handle) to the cloned scope operations. This effectively prevents
176 // the transformation from accessing any IR outside the scope.
177 auto scope = state.make_region_scope(reg);
178 auto clones = llvm::map_to_vector(
179 originals, [](Operation *op) { return op->clone(); });
 // Erase the clones on every exit path unless release() is called below.
180 llvm::scope_exit deleteClones([&] {
181 for (Operation *clone : clones)
182 clone->erase();
183 });
184 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
186
 // NOTE: this local named `failed` shadows the `failed(...)` helper, which is
 // why the call further down is spelled `::mlir::failed`.
187 bool failed = false;
188 for (Operation &transform : reg.front().without_terminator()) {
190 state.applyTransform(cast<TransformOpInterface>(transform));
 // A silenceable failure abandons this alternative and moves to the next.
191 if (result.isSilenceableFailure()) {
192 LDBG() << "alternative failed: " << result.getMessage();
193 failed = true;
194 break;
195 }
196
197 if (::mlir::failed(result.silence()))
199 }
200
201 // If all operations in the given alternative succeeded, no need to consider
202 // the rest. Replace the original scoping operation with the clone on which
203 // the transformations were performed.
204 if (!failed) {
205 // We will be using the clones, so cancel their scheduled deletion.
206 deleteClones.release();
 // This local `rewriter` shadows the function parameter of the same name and
 // reports replacements to the tracking listener.
207 TrackingListener listener(state, *this);
208 IRRewriter rewriter(getContext(), &listener);
209 for (const auto &kvp : llvm::zip(originals, clones)) {
210 Operation *original = std::get<0>(kvp);
211 Operation *clone = std::get<1>(kvp);
 // Insert the clone right before the original, then replace the original's
 // results with the clone's results.
212 original->getBlock()->getOperations().insert(original->getIterator(),
213 clone);
214 rewriter.replaceOp(original, clone->getResults());
215 }
216 detail::forwardTerminatorOperands(&reg.front(), state, results);
218 }
219 }
220 return emitSilenceableError() << "all alternatives failed";
221}
222
// Declares the op's effects: it consumes all of its operands, produces all of
// its results and the block arguments of every non-empty region, and modifies
// the payload.
// NOTE(review): the parameter-list line (listing line 224) is absent here.
223void transform::AlternativesOp::getEffects(
225 consumesHandle(getOperation()->getOpOperands(), effects);
226 producesHandle(getOperation()->getOpResults(), effects);
227 for (Region *region : getRegions()) {
228 if (!region->empty())
229 producesHandle(region->front().getArguments(), effects);
230 }
231 modifiesPayload(effects);
232}
233
// Verifies that, in every alternative region, the types of the terminator's
// operands equal the types of this op's results.
// NOTE(review): listing line 239 (start of the `diag` declaration) is absent.
234LogicalResult transform::AlternativesOp::verify() {
235 for (Region &alternative : getAlternatives()) {
236 Block &block = alternative.front();
237 Operation *terminator = block.getTerminator();
238 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
240 << "expects terminator operands to have the "
241 "same type as results of the operation";
242 diag.attachNote(terminator->getLoc()) << "terminator";
243 return diag;
244 }
245 }
246
247 return success();
248}
249
250//===----------------------------------------------------------------------===//
251// AnnotateOp
252//===----------------------------------------------------------------------===//
253
// Sets the attribute named getName() on every target payload op.
//  - No param operand: the attribute value is a UnitAttr.
//  - Param with exactly one value: that value is set on every target.
//  - Param with any other number of values: the count must equal the number
//    of targets (else a silenceable error) and values are assigned pairwise.
// NOTE(review): listing lines 254, 256-258, 272 and 278 are absent (return
// type, remaining parameters, the `targets` declaration head and the return
// statements) -- confirm against the original file.
255transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
259 llvm::to_vector(state.getPayloadOps(getTarget()));
260
261 Attribute attr = UnitAttr::get(getContext());
262 if (auto paramH = getParam()) {
263 ArrayRef<Attribute> params = state.getParams(paramH);
264 if (params.size() != 1) {
265 if (targets.size() != params.size()) {
266 return emitSilenceableError()
267 << "parameter and target have different payload lengths ("
268 << params.size() << " vs " << targets.size() << ")";
269 }
 // The structured binding `attr` below shadows the outer `attr`.
270 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
271 target->setAttr(getName(), attr);
273 }
274 attr = params[0];
275 }
276 for (auto *target : targets)
277 target->setAttr(getName(), attr);
279}
280
// Declares that the op only reads its target and param handles and that it
// modifies the payload.
// NOTE(review): the parameter-list line (listing line 282) is absent here.
281void transform::AnnotateOp::getEffects(
283 onlyReadsHandle(getTargetMutable(), effects);
284 onlyReadsHandle(getParamMutable(), effects);
285 modifiesPayload(effects);
286}
287
288//===----------------------------------------------------------------------===//
289// ApplyCommonSubexpressionEliminationOp
290//===----------------------------------------------------------------------===//
291
// Per-target entry point of the CSE transform op. The visible part guards
// against applying the transform to itself and creates a DominanceInfo.
// NOTE(review): listing lines 292, 294, 299 and 304-305 are absent (return
// type, leading parameters, the initializer of `payloadCheck`, and the
// statements after `domInfo` -- presumably the CSE call and the return).
293transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
295 ApplyToEachResultList &results, transform::TransformState &state) {
296 // Make sure that this transform is not applied to itself. Modifying the
297 // transform IR while it is being interpreted is generally dangerous.
298 DiagnosedSilenceableFailure payloadCheck =
300 if (!payloadCheck.succeeded())
301 return payloadCheck;
302
303 DominanceInfo domInfo;
306}
307
// Declares that the op only reads its target handle.
// NOTE(review): listing lines 309 and 311 are absent (parameter list and,
// presumably, a further effect declaration) -- confirm.
308void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
310 transform::onlyReadsHandle(getTargetMutable(), effects);
312}
313
314//===----------------------------------------------------------------------===//
315// ApplyDeadCodeEliminationOp
316//===----------------------------------------------------------------------===//
317
// Per-target entry point of the DCE transform op: after the self-application
// guard, calls eliminateTriviallyDeadOps on every region of `target`.
// NOTE(review): listing lines 319, 324 and 331 are absent (leading
// parameters, the initializer of `payloadCheck`, and the final return).
318DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
320 ApplyToEachResultList &results, transform::TransformState &state) {
321 // Make sure that this transform is not applied to itself. Modifying the
322 // transform IR while it is being interpreted is generally dangerous.
323 DiagnosedSilenceableFailure payloadCheck =
325 if (!payloadCheck.succeeded())
326 return payloadCheck;
327
328 for (Region &region : target->getRegions())
329 eliminateTriviallyDeadOps(rewriter, region);
330
332}
333
// Declares that the op only reads its target handle.
// NOTE(review): listing lines 335 and 337 are absent (parameter list and,
// presumably, a further effect declaration) -- confirm.
334void transform::ApplyDeadCodeEliminationOp::getEffects(
336 transform::onlyReadsHandle(getTargetMutable(), effects);
338}
339
340//===----------------------------------------------------------------------===//
341// ApplyPatternsOp
342//===----------------------------------------------------------------------===//
343
// Applies the patterns described by the ops in this op's region to one target
// payload op using the greedy pattern rewrite driver.
//  - Isolated-from-above target: one driver call on the whole op, with CSE
//    between iterations enabled per getApplyCse().
//  - Otherwise: the nested ops are collected and rewritten in a loop that
//    repeats while CSE reports changes, capped at kNumMaxIterations.
// NOTE(review): listing lines 345, 352, 373, 376, 384, 387, 393, 406, 412 and
// 420 are absent (parameters, the `payloadCheck` initializer, the
// "no limit" branches of the config ternaries, failure returns, the `ops`
// declaration, the CSE call head and the final return) -- confirm.
344DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
346 ApplyToEachResultList &results, transform::TransformState &state) {
347 // Make sure that this transform is not applied to itself. Modifying the
348 // transform IR while it is being interpreted is generally dangerous. Even
349 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
350 // performs many additional simplifications such as dead code elimination.
351 DiagnosedSilenceableFailure payloadCheck =
353 if (!payloadCheck.succeeded())
354 return payloadCheck;
355
356 // Gather all specified patterns.
357 MLIRContext *ctx = target->getContext();
358 RewritePatternSet patterns(ctx);
359 if (!getRegion().empty()) {
 // The unchecked cast relies on verify() having rejected ops that lack the
 // PatternDescriptorOpInterface.
360 for (Operation &op : getRegion().front()) {
361 cast<transform::PatternDescriptorOpInterface>(&op)
362 .populatePatternsWithState(patterns, state);
363 }
364 }
365
366 // Configure the GreedyPatternRewriteDriver.
367 GreedyRewriteConfig config;
368 config.setListener(
369 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
370 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
371
 // An attribute value of uint64_t(-1) is treated as a sentinel and mapped to
 // a different config value (on the missing ternary branch).
372 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
374 : getMaxIterations());
375 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
377 : getMaxNumRewrites());
378
379 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
380 // Op is isolated from above. The greedy driver iterates to a fixpoint
381 // internally and optionally runs full CSE between iterations.
382 config.enableCSEBetweenIterations(getApplyCse());
383 if (failed(applyPatternsGreedily(target, frozenPatterns, config))) {
385 << "greedy pattern application failed";
386 }
388 }
389
390 // Non-isolated case: gather the ops manually because the op-list
391 // GreedyPatternRewriteDriver overload only performs a single iteration and
392 // does not simplify regions. CSE is driven externally to reach a fixpoint.
 // The target itself is excluded; only ops nested under it are collected.
394 target->walk([&](Operation *nestedOp) {
395 if (target != nestedOp)
396 ops.push_back(nestedOp);
397 });
398
399 // One or two iterations should be sufficient. Stop iterating after a certain
400 // threshold to make debugging easier.
401 static const int64_t kNumMaxIterations = 50;
402 int64_t iteration = 0;
403 bool cseChanged = false;
404 do {
405 if (failed(applyOpPatternsGreedily(ops, frozenPatterns, config))) {
407 << "greedy pattern application failed";
408 }
409
410 if (getApplyCse()) {
411 DominanceInfo domInfo;
413 &cseChanged);
414 }
 // `iteration` is only incremented when CSE changed something, so it reaches
 // the cap only if every round up to the cap reported a change.
415 } while (cseChanged && ++iteration < kNumMaxIterations);
416
417 if (iteration == kNumMaxIterations)
418 return emitDefiniteFailure() << "fixpoint iteration did not converge";
419
421}
422
// Verifies that every op in the (non-empty) region implements
// PatternDescriptorOpInterface; the first offender is reported with a note.
// NOTE(review): listing line 427 (start of the `diag` declaration) is absent.
423LogicalResult transform::ApplyPatternsOp::verify() {
424 if (!getRegion().empty()) {
425 for (Operation &op : getRegion().front()) {
426 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
428 << "expected children ops to implement "
429 "PatternDescriptorOpInterface";
430 diag.attachNote(op.getLoc()) << "op without interface";
431 return diag;
432 }
433 }
434 }
435 return success();
436}
437
// Declares that the op only reads its target handle.
// NOTE(review): listing lines 439 and 441 are absent (parameter list and,
// presumably, a further effect declaration) -- confirm.
438void transform::ApplyPatternsOp::getEffects(
440 transform::onlyReadsHandle(getTargetMutable(), effects);
442}
443
// Builder: adds `target` as operand, creates the op's region with a single
// block, and lets the optional `bodyBuilder` populate that block.
// NOTE(review): listing line 445 (leading parameters) is absent here.
444void transform::ApplyPatternsOp::build(
446 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
447 result.addOperands(target);
448
 // The guard restores the builder's insertion point when this scope ends.
449 OpBuilder::InsertionGuard g(builder);
450 Region *region = result.addRegion();
451 builder.createBlock(region);
452 if (bodyBuilder)
453 bodyBuilder(builder, result.location);
454}
455
456//===----------------------------------------------------------------------===//
457// ApplyCanonicalizationPatternsOp
458//===----------------------------------------------------------------------===//
459
// Adds canonicalization patterns to `patterns`: first those of every dialect
// loaded in the pattern set's context, then per-op ones.
// NOTE(review): listing line 465 is absent; presumably it is the loop header
// that declares `op` -- confirm.
460void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
461 RewritePatternSet &patterns) {
462 MLIRContext *ctx = patterns.getContext();
463 for (Dialect *dialect : ctx->getLoadedDialects())
464 dialect->getCanonicalizationPatterns(patterns);
466 op.getCanonicalizationPatterns(patterns, ctx);
467}
468
469//===----------------------------------------------------------------------===//
470// ApplyConversionPatternsOp
471//===----------------------------------------------------------------------===//
472
// Runs a dialect conversion on every target payload op.
// Steps: build the optional default type converter, configure the
// ConversionTarget from the legal/illegal op and dialect attributes, collect
// patterns from the descriptor ops in the patterns region (each with its own
// or the default type converter), then apply partial or full conversion per
// target and merge conversion and tracking-listener failures.
// NOTE(review): listing lines 474-475, 508, 524, 559, 573 and 596 are absent
// (parameters, the `keepAliveConverters` declaration, two `diag`
// declarations, the `payloadCheck` initializer and the final return).
473DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
476 MLIRContext *ctx = getContext();
477
478 // Instantiate the default type converter if a type converter builder is
479 // specified.
480 std::unique_ptr<TypeConverter> defaultTypeConverter;
481 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
482 getDefaultTypeConverter();
483 if (typeConverterBuilder)
484 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
485
486 // Configure conversion target.
487 ConversionTarget conversionTarget(*getContext());
488 if (getLegalOps())
489 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
490 conversionTarget.addLegalOp(
491 OperationName(cast<StringAttr>(attr).getValue(), ctx));
492 if (getIllegalOps())
493 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
494 conversionTarget.addIllegalOp(
495 OperationName(cast<StringAttr>(attr).getValue(), ctx));
496 if (getLegalDialects())
497 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
498 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
499 if (getIllegalDialects())
500 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
501 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
502
503 // Gather all specified patterns.
504 RewritePatternSet patterns(ctx);
505 // Need to keep the converters alive until after pattern application because
506 // the patterns take a reference to an object that would otherwise get out of
507 // scope.
509 if (!getPatterns().empty()) {
510 for (Operation &op : getPatterns().front()) {
511 auto descriptor =
512 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
513
514 // Check if this pattern set specifies a type converter.
515 std::unique_ptr<TypeConverter> typeConverter =
516 descriptor.getTypeConverter();
517 TypeConverter *converter = nullptr;
518 if (typeConverter) {
 // Ownership moves into keepAliveConverters; `converter` is a raw view.
519 keepAliveConverters.emplace_back(std::move(typeConverter));
520 converter = keepAliveConverters.back().get();
521 } else {
522 // No type converter specified: Use the default type converter.
523 if (!defaultTypeConverter) {
525 << "pattern descriptor does not specify type "
526 "converter and apply_conversion_patterns op has "
527 "no default type converter";
528 diag.attachNote(op.getLoc()) << "pattern descriptor op";
529 return diag;
530 }
531 converter = defaultTypeConverter.get();
532 }
533
534 // Add descriptor-specific updates to the conversion target, which may
535 // depend on the final type converter. In structural converters, the
536 // legality of types dictates the dynamic legality of an operation.
537 descriptor.populateConversionTargetRules(*converter, conversionTarget);
538
539 descriptor.populatePatterns(*converter, patterns);
540 }
541 }
542
543 // Attach a tracking listener if handles should be preserved. We configure the
544 // listener to allow op replacements with different names, as conversion
545 // patterns typically replace ops with replacement ops that have a different
546 // name.
547 TrackingListenerConfig trackingConfig;
548 trackingConfig.requireMatchingReplacementOpName = false;
549 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
550 ConversionConfig conversionConfig;
 // The listener is always constructed but only attached when handles are to
 // be preserved.
551 if (getPreserveHandles())
552 conversionConfig.listener = &trackingListener;
553
554 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
555 for (Operation *target : state.getPayloadOps(getTarget())) {
556 // Make sure that this transform is not applied to itself. Modifying the
557 // transform IR while it is being interpreted is generally dangerous.
558 DiagnosedSilenceableFailure payloadCheck =
560 if (!payloadCheck.succeeded())
561 return payloadCheck;
562
563 LogicalResult status = failure();
564 if (getPartialConversion()) {
565 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
566 conversionConfig);
567 } else {
568 status = applyFullConversion(target, conversionTarget, frozenPatterns,
569 conversionConfig);
570 }
571
572 // Check dialect conversion state.
574 if (failed(status)) {
575 diag = emitSilenceableError() << "dialect conversion failed";
576 diag.attachNote(target->getLoc()) << "target op";
577 }
578
579 // Check tracking listener error state.
580 DiagnosedSilenceableFailure trackingFailure =
581 trackingListener.checkAndResetError();
582 if (!trackingFailure.succeeded()) {
583 if (diag.succeeded()) {
584 // Tracking failure is the only failure.
585 return trackingFailure;
586 }
 // Both failed: fold the tracking message into the conversion diagnostic as a
 // note and silence the tracking failure itself.
587 diag.attachNote() << "tracking listener also failed: "
588 << trackingFailure.getMessage();
589 (void)trackingFailure.silence();
590 }
591
 // The first failing target aborts processing of the remaining targets.
592 if (!diag.succeeded())
593 return diag;
594 }
595
597}
598
// Verifies the op structure:
//  - it has 1 or 2 regions;
//  - every op in the patterns region implements
//    ConversionPatternDescriptorOpInterface;
//  - with 2 regions, the second holds exactly one op that implements
//    TypeConverterBuilderOpInterface, and every pattern descriptor accepts it
//    via verifyTypeConverter.
// NOTE(review): listing lines 605 and 622 (starts of the `diag`
// declarations) are absent here.
599LogicalResult transform::ApplyConversionPatternsOp::verify() {
600 if (getNumRegions() != 1 && getNumRegions() != 2)
601 return emitOpError() << "expected 1 or 2 regions";
602 if (!getPatterns().empty()) {
603 for (Operation &op : getPatterns().front()) {
604 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
606 emitOpError() << "expected pattern children ops to implement "
607 "ConversionPatternDescriptorOpInterface";
608 diag.attachNote(op.getLoc()) << "op without interface";
609 return diag;
610 }
611 }
612 }
613 if (getNumRegions() == 2) {
614 Region &typeConverterRegion = getRegion(1);
 // NOTE(review): front() is called without a visible emptiness check on the
 // region -- presumably guaranteed elsewhere; confirm.
615 if (!llvm::hasSingleElement(typeConverterRegion.front()))
616 return emitOpError()
617 << "expected exactly one op in default type converter region";
618 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
619 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
620 maybeTypeConverter);
621 if (!typeConverterOp) {
623 << "expected default converter child op to "
624 "implement TypeConverterBuilderOpInterface";
625 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
626 return diag;
627 }
628 // Check default type converter type.
629 if (!getPatterns().empty()) {
 // The unchecked cast is safe: the interface was verified in the loop above.
630 for (Operation &op : getPatterns().front()) {
631 auto descriptor =
632 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
633 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
634 return failure();
635 }
636 }
637 }
638 return success();
639}
640
// Declares the effect on the target handle: consumed unless handles are
// preserved, in which case it is only read.
// NOTE(review): listing lines 642 and 648 are absent (parameter list and,
// presumably, a further effect declaration) -- confirm.
641void transform::ApplyConversionPatternsOp::getEffects(
643 if (!getPreserveHandles()) {
644 transform::consumesHandle(getTargetMutable(), effects);
645 } else {
646 transform::onlyReadsHandle(getTargetMutable(), effects);
647 }
649}
650
// Builder: adds `target` as operand and creates two regions, each with one
// block. The first block is populated by the optional `patternsBodyBuilder`,
// the second by the optional `typeConverterBodyBuilder`.
// NOTE(review): listing line 652 (leading parameters) is absent here.
651void transform::ApplyConversionPatternsOp::build(
653 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
654 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
655 result.addOperands(target);
656
 // Each region is built in its own scope so that its insertion guard restores
 // the builder's insertion point before the next region is created.
657 {
658 OpBuilder::InsertionGuard g(builder);
659 Region *region1 = result.addRegion();
660 builder.createBlock(region1);
661 if (patternsBodyBuilder)
662 patternsBodyBuilder(builder, result.location);
663 }
664 {
665 OpBuilder::InsertionGuard g(builder);
666 Region *region2 = result.addRegion();
667 builder.createBlock(region2);
668 if (typeConverterBodyBuilder)
669 typeConverterBodyBuilder(builder, result.location);
670 }
671}
672
673//===----------------------------------------------------------------------===//
674// ApplyToLLVMConversionPatternsOp
675//===----------------------------------------------------------------------===//
676
// Adds the convert-to-LLVM patterns of the dialect named by getDialectName()
// to `patterns`, via that dialect's ConvertToLLVMPatternInterface.
// `typeConverter` is downcast to LLVMTypeConverter with an unchecked
// static_cast.
// NOTE(review): listing line 684 is absent; presumably it declares the local
// `target` passed below -- confirm.
677void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
678 TypeConverter &typeConverter, RewritePatternSet &patterns) {
679 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
680 assert(dialect && "expected that dialect is loaded");
681 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
682 // ConversionTarget is currently ignored because the enclosing
683 // apply_conversion_patterns op sets up its own ConversionTarget.
685 iface->populateConvertToLLVMConversionPatterns(
686 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
687}
688
689LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
690 transform::TypeConverterBuilderOpInterface builder) {
691 if (builder.getTypeConverterType() != "LLVMTypeConverter")
692 return emitOpError("expected LLVMTypeConverter");
693 return success();
694}
695
696LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
697 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
698 if (!dialect)
699 return emitOpError("unknown dialect or dialect not loaded: ")
700 << getDialectName();
701 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
702 if (!iface)
703 return emitOpError(
704 "dialect does not implement ConvertToLLVMPatternInterface or "
705 "extension was not loaded: ")
706 << getDialectName();
707 return success();
708}
709
710//===----------------------------------------------------------------------===//
711// ApplyLoopInvariantCodeMotionOp
712//===----------------------------------------------------------------------===//
713
// Per-target entry point of the LICM transform op; `target` is a
// LoopLikeOpInterface.
// NOTE(review): listing lines 714, 717-718 and 721-722 are absent (return
// type, trailing parameters, and the statements of the body -- presumably the
// LICM call and the return). Only the explanatory comment survives here.
715transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
716 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
719 // Currently, LICM does not remove operations, so we don't need tracking.
720 // If this ever changes, add a LICM entry point that takes a rewriter.
723}
724
// Declares that the op only reads its target handle.
// NOTE(review): listing lines 726 and 728 are absent (parameter list and,
// presumably, a further effect declaration) -- confirm.
725void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
727 transform::onlyReadsHandle(getTargetMutable(), effects);
729}
730
731//===----------------------------------------------------------------------===//
732// ApplyRegisteredPassOp
733//===----------------------------------------------------------------------===//
734
// Declares the op's effects: the target handle is consumed, the dynamic
// option handles are only read, all results are produced, and the payload is
// modified.
// NOTE(review): the parameter-list line (listing line 736) is absent here.
735void transform::ApplyRegisteredPassOp::getEffects(
737 consumesHandle(getTargetMutable(), effects);
738 onlyReadsHandle(getDynamicOptionsMutable(), effects);
739 producesHandle(getOperation()->getOpResults(), effects);
740 modifiesPayload(effects);
741}
742
// Runs a registered pass or pass pipeline on every target payload op.
// Steps: serialize the options dictionary (resolving param operands) into a
// single option string, look up getPassName() first as a pipeline and then as
// a pass, add it to a pass manager, run it on each target, and map the result
// handle to the same target ops.
// NOTE(review): listing lines 743, 745-746, 802, 819 and 833 are absent
// (return type, remaining parameters, the declaration of `pm`, the
// `payloadCheck` initializer and the final return) -- confirm.
744transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
747 // Obtain a single options-string to pass to the pass(-pipeline) from options
748 // passed in as a dictionary of keys mapping to values which are either
749 // attributes or param-operands pointing to attributes.
750 OperandRange dynamicOptions = getDynamicOptions();
751
752 std::string options;
753 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
754
755 // A helper to convert an option's attribute value into a corresponding
756 // string representation, with the ability to obtain the attr(s) from a param.
 // Declared as std::function so the lambda can refer to itself recursively.
757 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
758 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
759 // The corresponding value attribute(s) is/are passed in via a param.
760 // Obtain the param-operand via its specified index.
761 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
 // NOTE(review): the two adjacent string literals below concatenate without a
 // separating space, so the message reads "...DictionaryAttrshould be...".
 // The assert also checks only the upper bound of the index.
762 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
763 "the number of ParamOperandAttrs in the options DictionaryAttr"
764 "should be the same as the number of options passed as params");
765 ArrayRef<Attribute> attrsAssociatedToParam =
766 state.getParams(dynamicOptions[dynamicOptionIdx]);
767 // Recursive so as to append all attrs associated to the param.
768 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
769 ",");
770 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
771 // Recursive so as to append all nested attrs of the array.
772 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
773 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
774 // Convert to unquoted string.
775 optionsStream << strAttr.getValue().str();
776 } else {
777 // For all other attributes, ask the attr to print itself (without type).
778 valueAttr.print(optionsStream, /*elideType=*/true);
779 }
780 };
781
782 // Convert the options DictionaryAttr into a single string.
 // Entries are emitted as `key=value` and separated by a single space.
783 llvm::interleave(
784 getOptions(), optionsStream,
785 [&](auto namedAttribute) {
786 optionsStream << namedAttribute.getName().str(); // Append the key.
787 optionsStream << "="; // And the key-value separator.
788 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
789 },
790 " ");
791 optionsStream.flush();
792
793 // Get pass or pass pipeline from registry.
 // A pipeline of the given name takes precedence over a pass of that name.
794 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
795 if (!info)
796 info = PassInfo::lookup(getPassName());
797 if (!info)
798 return emitDefiniteFailure()
799 << "unknown pass or pass pipeline: " << getPassName();
800
801 // Create pass manager and add the pass or pass pipeline.
 // The error callback reports the registry's message on this op via
 // emitError; a definite failure naming the pass is emitted in addition.
803 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
804 emitError(msg);
805 return failure();
806 }))) {
807 return emitDefiniteFailure()
808 << "failed to add pass or pass pipeline to pipeline: "
809 << getPassName();
810 }
811
 // Copy the payload ops into a vector; the same list is reused for the result
 // handle after the passes have run.
812 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
813 for (Operation *target : targets) {
814 // Make sure that this transform is not applied to itself. Modifying the
815 // transform IR while it is being interpreted is generally dangerous. Even
816 // more so when applying passes because they may perform a wide range of IR
817 // modifications.
818 DiagnosedSilenceableFailure payloadCheck =
820 if (!payloadCheck.succeeded())
821 return payloadCheck;
822
823 // Run the pass or pass pipeline on the current target operation.
824 if (failed(pm.run(target))) {
825 auto diag = emitSilenceableError() << "pass pipeline failed";
826 diag.attachNote(target->getLoc()) << "target op";
827 return diag;
828 }
829 }
830
831 // The applied pass will have directly modified the payload IR(s).
832 results.set(llvm::cast<OpResult>(getResult()), targets);
834}
835
// Custom parser for the options dictionary of the registered-pass op.
// Values may be plain attributes, SSA operands, or square-bracketed lists of
// either. Every operand is appended to `dynamicOptions` and represented in
// the dictionary by a ParamOperandAttr holding its index. On success,
// `options` is set to a sorted DictionaryAttr; duplicate keys are an error.
// NOTE(review): listing lines 836, 838, 851, 855, 870 and 919-920 are absent
// (the function header, the `dynamicOptions` parameter, the `attrs` and
// `operand` declarations, the list delimiter argument, and the head of the
// top-level parseCommaSeparatedList call) -- confirm.
837 OpAsmParser &parser, DictionaryAttr &options,
839 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
840 SmallVector<NamedAttribute> keyValuePairs;
 // Running index of the next operand; it is stored in the ParamOperandAttr
 // that stands in for that operand.
841 size_t dynamicOptionsIdx = 0;
842
843 // Helper for allowing parsing of option values which can be of the form:
844 // - a normal attribute
845 // - an operand (which would be converted to an attr referring to the operand)
846 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
847 std::function<ParseResult(Attribute &)> parseValue =
848 [&](Attribute &valueAttr) -> ParseResult {
849 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
850 if (succeeded(parser.parseOptionalLSquare())) {
852
853 // Recursively parse the array's elements, which might be operands.
854 if (parser.parseCommaSeparatedList(
856 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
857 " in options dictionary") ||
858 parser.parseRSquare())
859 return failure(); // NB: Attempted parse should've output error message.
860
861 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
862
863 return success();
864 }
865
866 // Parse the value, which can be either an attribute or an operand.
867 OptionalParseResult parsedValueAttr =
868 parser.parseOptionalAttribute(valueAttr);
 // No value means no attribute was present at this position, so an operand is
 // required instead.
869 if (!parsedValueAttr.has_value()) {
871 ParseResult parsedOperand = parser.parseOperand(operand);
872 if (failed(parsedOperand))
873 return failure(); // NB: Attempted parse should've output error message.
874 // To make use of the operand, we need to store it in the options dict.
875 // As SSA-values cannot occur in attributes, what we do instead is store
876 // an attribute in its place that contains the index of the param-operand,
877 // so that an attr-value associated to the param can be resolved later on.
878 dynamicOptions.push_back(operand);
879 auto wrappedIndex = IntegerAttr::get(
880 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
881 valueAttr =
882 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
883 } else if (failed(parsedValueAttr.value())) {
884 return failure(); // NB: Attempted parse should have output error message.
885 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
 // A literally written ParamOperandAttr is rejected here.
886 return parser.emitError(parser.getCurrentLocation())
887 << "the param_operand attribute is a marker reserved for "
888 << "indicating a value will be passed via params and is only used "
889 << "in the generic print format";
890 }
891
892 return success();
893 };
894
895 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
896 // string and `value` looks like either an attribute or an operand-in-an-attr.
897 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
898 std::string key;
899 Attribute valueAttr;
900
 // An empty key is rejected even when the parse call itself succeeded.
901 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
902 return parser.emitError(parser.getCurrentLocation())
903 << "expected key to either be an identifier or a string";
904
905 if (failed(parser.parseEqual()))
906 return parser.emitError(parser.getCurrentLocation())
907 << "expected '=' after key in key-value pair";
908
909 if (failed(parseValue(valueAttr)))
910 return parser.emitError(parser.getCurrentLocation())
911 << "expected a valid attribute or operand as value associated "
912 << "to key '" << key << "'";
913
914 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
915
916 return success();
917 };
918
921 " in options dictionary"))
922 return failure(); // NB: Attempted parse should have output error message.
923
 // findDuplicate also sorts keyValuePairs in place, which the getWithSorted
 // call below depends on.
924 if (DictionaryAttr::findDuplicate(
925 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
926 .has_value())
927 return parser.emitError(parser.getCurrentLocation())
928 << "duplicate keys found in options dictionary";
929
930 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
931
932 return success();
933}
934
936 Operation *op,
937 DictionaryAttr options,
938 ValueRange dynamicOptions) {
// Prints `options` as `{key = value, ...}`; an empty dictionary prints nothing.
// Values that are ParamOperandAttr markers are printed as the SSA operand they
// index in `dynamicOptions` instead of as attributes.
// NOTE(review): the first line of this signature is absent from this listing.
939 if (options.empty())
940 return;
941
// Declared as std::function so the lambda can pass itself to interleaveComma
// when descending into ArrayAttr elements.
942 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
943 if (auto paramOperandAttr =
944 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
945 // Resolve index of param-operand to its actual SSA-value and print that.
// NOTE(review): the index is not bounds-checked here;
// ApplyRegisteredPassOp::verify() rejects out-of-range indices - confirm this
// printer is never reached for an op that has not been verified.
946 printer.printOperand(
947 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
948 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
949 // This case is so that ArrayAttr-contained operands are pretty-printed.
950 printer << "[";
951 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
952 printer << "]";
953 } else {
// Any other attribute kind is printed through the regular attribute printer.
954 printer.printAttribute(valueAttr);
955 }
956 };
957
// Emit the braces and the comma-separated `name = value` entries.
958 printer << "{";
959 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
960 printer << namedAttribute.getName();
961 printer << " = ";
962 printOptionValue(namedAttribute.getValue());
963 });
964 printer << "}";
965}
966
967LogicalResult transform::ApplyRegisteredPassOp::verify() {
968 // Check that there is a one-to-one correspondence between param operands
969 // and references to dynamic options in the options dictionary.
970
971 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
972
973 // Helper for option values to mark seen operands as having been seen (once).
974 std::function<LogicalResult(Attribute)> checkOptionValue =
975 [&](Attribute valueAttr) -> LogicalResult {
976 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
977 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
978 if (dynamicOptionIdx < 0 ||
979 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
980 return emitOpError()
981 << "dynamic option index " << dynamicOptionIdx
982 << " is out of bounds for the number of dynamic options: "
983 << dynamicOptions.size();
984 if (dynamicOptions[dynamicOptionIdx] == nullptr)
985 return emitOpError() << "dynamic option index " << dynamicOptionIdx
986 << " is already used in options";
987 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
988 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
989 // Recurse into ArrayAttrs as they may contain references to operands.
990 for (auto eltAttr : arrayAttr)
991 if (failed(checkOptionValue(eltAttr)))
992 return failure();
993 }
994 return success();
995 };
996
997 for (NamedAttribute namedAttr : getOptions())
998 if (failed(checkOptionValue(namedAttr.getValue())))
999 return failure();
1000
1001 // All dynamicOptions-params seen in the dict will have been set to null.
1002 for (Value dynamicOption : dynamicOptions)
1003 if (dynamicOption)
1004 return emitOpError() << "a param operand does not have a corresponding "
1005 << "param_operand attr in the options dict";
1006
1007 return success();
1008}
1009
1010//===----------------------------------------------------------------------===//
1011// CastOp
1012//===----------------------------------------------------------------------===//
1013
// Forwards the payload op `target` unchanged by appending it to `results`.
// NOTE(review): the return-type line, the last parameter line and the return
// statement of this definition are absent from this listing.
1015transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1016 Operation *target, ApplyToEachResultList &results,
1018 results.push_back(target);
1020}
1021
// Populates `effects` through three helper calls: onlyReadsPayload,
// onlyReadsHandle on the input operand, and producesHandle on the op results.
// NOTE(review): the parameter line declaring `effects` is absent from this
// listing.
1022void transform::CastOp::getEffects(
1024 onlyReadsPayload(effects);
1025 onlyReadsHandle(getInputMutable(), effects);
1026 producesHandle(getOperation()->getOpResults(), effects);
1027}
1028
1029bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1030 assert(inputs.size() == 1 && "expected one input");
1031 assert(outputs.size() == 1 && "expected one output");
1032 return llvm::all_of(
1033 std::initializer_list<Type>{inputs.front(), outputs.front()},
1034 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1035}
1036
1037//===----------------------------------------------------------------------===//
1038// Matching helpers (shared by CollectMatchingOp and ForeachMatchOp)
1039//===----------------------------------------------------------------------===//
1040
1041/// Applies matcher operations from the given `block` using
1042/// `blockArgumentMapping` to initialize block arguments. Updates `state`
1043/// accordingly. If any of the matcher produces a silenceable failure, discards
1044/// it (printing the content to the debug output stream) and returns failure. If
1045/// any of the matchers produces a definite failure, reports it and returns
1046/// failure. If all matchers in the block succeed, populates `mappings` with the
1047/// payload entities associated with the block terminator operands. Note that
1048/// `mappings` will be cleared before that.
///
/// NOTE(review): the visible loop body returns the first non-succeeding
/// diagnostic as-is (`return diag;`); nothing visible here discards or prints
/// it, so the wording above may be stale - confirm against the callers, which
/// inspect the returned diagnostic themselves.
/// NOTE(review): the signature lines, the early-return statement after
/// mapBlockArguments, the declaration of `diag` and the final return are
/// absent from this listing.
1051 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1054 assert(block.getParent() && "cannot match using a detached block");
// Region scope for the region that owns `block`; it is held in a local, so it
// ends when this function returns.
1055 auto matchScope = state.make_region_scope(*block.getParent());
1056 if (failed(
1057 state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1059
// Apply every non-terminator op in order; each must implement
// MatchOpInterface, otherwise a definite failure is emitted at that op.
1060 for (Operation &match : block.without_terminator()) {
1061 if (!isa<transform::MatchOpInterface>(match)) {
1062 return emitDefiniteFailure(match.getLoc())
1063 << "expected operations in the match part to "
1064 "implement MatchOpInterface";
1065 }
1067 state.applyTransform(cast<transform::TransformOpInterface>(match));
1068 if (diag.succeeded())
1069 continue;
1070
// First matcher that does not succeed ends the matching.
1071 return diag;
1072 }
1073
1074 // Remember the values mapped to the terminator operands so we can
1075 // forward them to the action.
1076 ValueRange yieldedValues = block.getTerminator()->getOperands();
1077 // Our contract with the caller is that the mappings will contain only the
1078 // newly mapped values, clear the rest.
1079 mappings.clear();
1080 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1082}
1083
1084/// Returns `true` if both types implement one of the interfaces provided as
1085/// template parameters.
1086template <typename... Tys>
1087static bool implementSameInterface(Type t1, Type t2) {
1088 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1089}
1090
1091/// Returns `true` if both types implement one of the transform dialect
1092/// interfaces.
/// The interfaces checked are the three template arguments passed below:
/// TransformHandleTypeInterface, TransformParamTypeInterface and
/// TransformValueHandleTypeInterface. Both types must implement the same one.
/// NOTE(review): the signature line of this function is absent from this
/// listing.
1094 return implementSameInterface<transform::TransformHandleTypeInterface,
1095 transform::TransformParamTypeInterface,
1096 transform::TransformValueHandleTypeInterface>(
1097 t1, t2);
1098}
1099
1100//===----------------------------------------------------------------------===//
1101// CollectMatchingOp
1102//===----------------------------------------------------------------------===//
1103
// Walks every payload op nested under each root payload op, runs the matcher
// function's entry block on it, and appends the single payload entity yielded
// for each matcher result to the corresponding op result.
// NOTE(review): several lines of this definition (return type, remaining
// parameters, the declarations of `matcher`, `rawResults`, `mappings` and
// `diag`, and the final return) are absent from this listing.
1105transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1109 getOperation(), getMatcher());
// A matcher without a body cannot be executed.
1110 if (matcher.isExternal()) {
1111 return emitDefiniteFailure()
1112 << "unresolved external symbol " << getMatcher();
1113 }
1114
// One accumulator per op result.
1116 rawResults.resize(getOperation()->getNumResults());
// Carries a failure out of the walk callback, which can only return a
// WalkResult.
1117 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1118 for (Operation *root : state.getPayloadOps(getRoot())) {
1119 WalkResult walkResult = root->walk([&](Operation *op) {
1120 LDBG(DEBUG_TYPE_MATCHER, 1)
1121 << "matching "
1122 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1123 << " @" << op;
1124
1125 // Try matching.
1127 SmallVector<transform::MappedValue> inputMapping({op});
1129 matcher.getFunctionBody().front(),
1130 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1131 mappings);
// NOTE(review): this interrupt does not populate `maybeFailure`, yet the code
// after the walk dereferences `maybeFailure` whenever the walk was
// interrupted. Unless something not visible here sets it, a definite failure
// leads to dereferencing an empty optional - confirm.
1132 if (diag.isDefiniteFailure())
1133 return WalkResult::interrupt();
// A silenceable failure means "no match": log it and move on to the next op.
1134 if (diag.isSilenceableFailure()) {
1135 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1136 << " failed: " << diag.getMessage();
1137 return WalkResult::advance();
1138 }
1139
1140 // If succeeded, collect results.
// Each matcher result must be associated with exactly one payload entity.
1141 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1142 if (mapping.size() != 1) {
1143 maybeFailure.emplace(emitSilenceableError()
1144 << "result #" << i << ", associated with "
1145 << mapping.size()
1146 << " payload objects, expected 1");
1147 return WalkResult::interrupt();
1148 }
1149 rawResults[i].push_back(mapping[0]);
1150 }
1151 return WalkResult::advance();
1152 });
1153 if (walkResult.wasInterrupted())
1154 return std::move(*maybeFailure);
1155 assert(!maybeFailure && "failure set but the walk was not interrupted");
1156
// NOTE(review): this assignment sits inside the per-root loop, so results are
// set once per root and not at all when there are no roots - confirm that is
// intended.
1157 for (auto &&[opResult, rawResult] :
1158 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1159 results.setMappedValues(opResult, rawResult);
1160 }
1161 }
1163}
1164
// Populates `effects` through three helper calls: onlyReadsHandle on the root
// operand, producesHandle on the op results, and onlyReadsPayload.
// NOTE(review): the parameter line declaring `effects` is absent from this
// listing.
1165void transform::CollectMatchingOp::getEffects(
1167 onlyReadsHandle(getRootMutable(), effects);
1168 producesHandle(getOperation()->getOpResults(), effects);
1169 onlyReadsPayload(effects);
1170}
1171
1172LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1173 SymbolTableCollection &symbolTable) {
1174 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1175 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1176 if (!matcherSymbol ||
1177 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1178 return emitError() << "unresolved matcher symbol " << getMatcher();
1179
1180 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1181 if (argumentTypes.size() != 1 ||
1182 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1183 return emitError()
1184 << "expected the matcher to take one operation handle argument";
1185 }
1186 if (!matcherSymbol.getArgAttr(
1187 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1188 return emitError() << "expected the matcher argument to be marked readonly";
1189 }
1190
1191 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1192 if (resultTypes.size() != getOperation()->getNumResults()) {
1193 return emitError()
1194 << "expected the matcher to yield as many values as op has results ("
1195 << getOperation()->getNumResults() << "), got "
1196 << resultTypes.size();
1197 }
1198
1199 for (auto &&[i, matcherType, resultType] :
1200 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1201 if (implementSameTransformInterface(matcherType, resultType))
1202 continue;
1203
1204 return emitError()
1205 << "mismatching type interfaces for matcher result and op result #"
1206 << i;
1207 }
1208
1209 return success();
1210}
1211
1212//===----------------------------------------------------------------------===//
1213// ForeachMatchOp
1214//===----------------------------------------------------------------------===//
1215
1216// This is fine because nothing is actually consumed by this op.
1217bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1218
// For every payload op nested under each root, tries the (matcher, action)
// pairs in order and runs the action of the first matcher that succeeds.
// Silenceable action failures are accumulated into one overall diagnostic
// instead of stopping the walk; definite failures interrupt it.
// NOTE(review): several lines of this definition (return type, remaining
// parameters, the declarations of `matchActionPairs`, `diag` and `result`, the
// initializer of `overallDiag`, and some return statements) are absent from
// this listing.
1220transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1224 matchActionPairs;
1225 matchActionPairs.reserve(getMatchers().size());
// Resolve all matcher/action symbols up front.
1226 SymbolTableCollection symbolTable;
1227 for (auto &&[matcher, action] :
1228 llvm::zip_equal(getMatchers(), getActions())) {
1229 auto matcherSymbol =
1230 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1231 getOperation(), cast<SymbolRefAttr>(matcher));
1232 auto actionSymbol =
1233 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1234 getOperation(), cast<SymbolRefAttr>(action));
1235 assert(matcherSymbol && actionSymbol &&
1236 "unresolved symbols not caught by the verifier");
1237
// Declarations without a body cannot be executed.
1238 if (matcherSymbol.isExternal())
1239 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1240 if (actionSymbol.isExternal())
1241 return emitDefiniteFailure() << "unresolved external symbol " << action;
1242
1243 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1244 }
1245
// Accumulates silenceable action failures across the whole walk.
1246 DiagnosedSilenceableFailure overallDiag =
1248
1249 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1250 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1251 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1252 // Explicitly add the mapping for the first block argument (the op being
1253 // matched).
1254 matchInputMapping.emplace_back();
1256 getForwardedInputs(), state);
// Slot 0 is rewritten for every visited op, whereas the forwarded-input
// mappings are computed once above.
1257 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1258 actionResultMapping.resize(getForwardedOutputs().size());
1259
1260 for (Operation *root : state.getPayloadOps(getRoot())) {
1261 WalkResult walkResult = root->walk([&](Operation *op) {
1262 // If getRestrictRoot is not present, skip over the root op itself so we
1263 // don't invalidate it.
1264 if (!getRestrictRoot() && op == root)
1265 return WalkResult::advance();
1266
1267 LDBG(DEBUG_TYPE_MATCHER, 1)
1268 << "matching "
1269 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1270 << " @" << op;
1271
1272 firstMatchArgument.clear();
1273 firstMatchArgument.push_back(op);
1274
1275 // Try all the match/action pairs until the first successful match.
1276 for (auto [matcher, action] : matchActionPairs) {
1278 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1279 state, matchOutputMapping);
1280 if (diag.isDefiniteFailure())
1281 return WalkResult::interrupt();
// A silenceable matcher failure means "no match": try the next pair.
1282 if (diag.isSilenceableFailure()) {
1283 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1284 << " failed: " << diag.getMessage();
1285 continue;
1286 }
1287
// The matcher succeeded: bind what it yielded to the action's block arguments
// inside a region scope for the action body.
1288 auto scope = state.make_region_scope(action.getFunctionBody());
1289 if (failed(state.mapBlockArguments(
1290 action.getFunctionBody().front().getArguments(),
1291 matchOutputMapping))) {
1292 return WalkResult::interrupt();
1293 }
1294
1295 for (Operation &transform :
1296 action.getFunctionBody().front().without_terminator()) {
1298 state.applyTransform(cast<TransformOpInterface>(transform));
1299 if (result.isDefiniteFailure())
1300 return WalkResult::interrupt();
// Silenceable failures are recorded as notes on `overallDiag` and silenced;
// the remaining ops of the action body are still applied.
1301 if (result.isSilenceableFailure()) {
1302 if (overallDiag.succeeded()) {
1303 overallDiag = emitSilenceableError() << "actions failed";
1304 }
1305 overallDiag.attachNote(action->getLoc())
1306 << "failed action: " << result.getMessage();
1307 overallDiag.attachNote(op->getLoc())
1308 << "when applied to this matching payload";
1309 (void)result.silence();
1310 continue;
1311 }
1312 }
// Append whatever the action's terminator yields to the per-result
// accumulators.
1313 if (failed(detail::appendValueMappings(
1314 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1315 action.getFunctionBody().front().getTerminator()->getOperands(),
1316 state, getFlattenResults()))) {
1318 << "action @" << action.getName()
1319 << " has results associated with multiple payload entities, "
1320 "but flattening was not requested";
1321 return WalkResult::interrupt();
1322 }
// Only the first matching pair is applied to a given payload op.
1323 break;
1324 }
1325 return WalkResult::advance();
1326 });
1327 if (walkResult.wasInterrupted())
1329 }
1330
1331 // The root operation should not have been affected, so we can just reassign
1332 // the payload to the result. Note that we need to consume the root handle to
1333 // make sure any handles to operations inside, that could have been affected
1334 // by actions, are invalidated.
1335 results.set(llvm::cast<OpResult>(getUpdated()),
1336 state.getPayloadOps(getRoot()));
1337 for (auto &&[result, mapping] :
1338 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1339 results.setMappedValues(result, mapping);
1340 }
// Either success or the accumulated "actions failed" silenceable error.
1341 return overallDiag;
1342}
1343
1344void transform::ForeachMatchOp::getAsmResultNames(
1345 OpAsmSetValueNameFn setNameFn) {
1346 setNameFn(getUpdated(), "updated_root");
1347 for (Value v : getForwardedOutputs()) {
1348 setNameFn(v, "yielded");
1349 }
1350}
1351
// Populates `effects`: the root handle is consumed, forwarded inputs are only
// read, all op results are produced, and the payload is modified.
// NOTE(review): the parameter line declaring `effects` is absent from this
// listing.
1352void transform::ForeachMatchOp::getEffects(
1354 // Bail if invalid.
// With no operand or no result, only the payload-modification effect is
// registered and the operand/result accessors below are never reached.
1355 if (getOperation()->getNumOperands() < 1 ||
1356 getOperation()->getNumResults() < 1) {
1357 return modifiesPayload(effects);
1358 }
1359
1360 consumesHandle(getRootMutable(), effects);
1361 onlyReadsHandle(getForwardedInputsMutable(), effects);
1362 producesHandle(getOperation()->getOpResults(), effects);
1363 modifiesPayload(effects);
1364}
1365
1366/// Parses the comma-separated list of symbol reference pairs of the format
1367/// `@matcher -> @action`.
/// At least one pair is required (the loop body runs before the first comma
/// check). On success, `matchers` and `actions` are set to array attributes of
/// SymbolRefAttr, in parse order.
/// NOTE(review): the parameter line declaring `matchers` is absent from this
/// listing.
1368static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1370 ArrayAttr &actions) {
1371 StringAttr matcher;
1372 StringAttr action;
1373 SmallVector<Attribute> matcherList;
1374 SmallVector<Attribute> actionList;
1375 do {
// No diagnostic is emitted here on failure; failure() is returned directly.
1376 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1377 parser.parseSymbolName(action)) {
1378 return failure();
1379 }
1380 matcherList.push_back(SymbolRefAttr::get(matcher));
1381 actionList.push_back(SymbolRefAttr::get(action));
1382 } while (parser.parseOptionalComma().succeeded());
1383
1384 matchers = parser.getBuilder().getArrayAttr(matcherList);
1385 actions = parser.getBuilder().getArrayAttr(actionList);
1386 return success();
1387}
1388
1389/// Prints the comma-separated list of symbol reference pairs of the format
1390/// `@matcher -> @action`.
/// Each pair is printed on its own line, two indentation levels deeper than
/// the surrounding output; every pair except the last is followed by ", ".
/// NOTE(review): the first line of this signature is absent from this listing.
1392 ArrayAttr matchers, ArrayAttr actions) {
1393 printer.increaseIndent();
1394 printer.increaseIndent();
// The index sequence is zipped in only to detect the last pair.
1395 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1396 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1397 printer.printNewline();
1398 printer << cast<SymbolRefAttr>(matcher) << " -> "
1399 << cast<SymbolRefAttr>(action);
1400 if (idx != matchers.size() - 1)
1401 printer << ", ";
1402 }
// Restore the indentation level that was active on entry.
1403 printer.decreaseIndent();
1404 printer.decreaseIndent();
1405}
1406
// Checks that there are as many matchers as actions and at least one pair.
// A matcher listed more than once only triggers a warning and does not fail
// verification.
// NOTE(review): the declaration of `matcherNames` is absent from this listing.
1407LogicalResult transform::ForeachMatchOp::verify() {
1408 if (getMatchers().size() != getActions().size())
1409 return emitOpError() << "expected the same number of matchers and actions";
1410 if (getMatchers().empty())
1411 return emitOpError() << "expected at least one match/action pair";
1412
1414 for (Attribute name : getMatchers()) {
// insert(...).second is true for the first occurrence of a name only.
1415 if (matcherNames.insert(name).second)
1416 continue;
1417 emitWarning() << "matcher " << name
1418 << " is used more than once, only the first match will apply";
1419 }
1420
1421 return success();
1422}
1423
1424/// Checks that the attributes of the function-like operation have correct
1425/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1426/// annotations being present even if they can be inferred from the body.
/// If `emitWarnings`, additionally warns about arguments that are marked as
/// consumed but are not consumed in the body. Violations are reported as
/// silenceable errors on the op.
/// NOTE(review): the return-type line and the final return statement are
/// absent from this listing.
1428verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1429 bool alsoVerifyInternal = false) {
1430 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
// For external declarations there is no body to inspect, so the set of
// consumed arguments stays empty.
1431 llvm::SmallDenseSet<unsigned> consumedArguments;
1432 if (!op.isExternal()) {
1433 transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1434 consumedArguments);
1435 }
1436 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1437 bool isConsumed =
1438 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1439 nullptr;
1440 bool isReadOnly =
1441 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1442 nullptr;
1443 if (isConsumed && isReadOnly) {
1444 return transformOp.emitSilenceableError()
1445 << "argument #" << i << " cannot be both readonly and consumed";
1446 }
1447 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1448 return transformOp.emitSilenceableError()
1449 << "must provide consumed/readonly status for arguments of "
1450 "external or called ops";
1451 }
// The remaining checks compare the annotations against the body.
1452 if (op.isExternal())
1453 continue;
1454
// Only arguments explicitly marked readonly are reported here; an unannotated
// argument that is consumed in the body passes this check.
1455 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1456 return transformOp.emitSilenceableError()
1457 << "argument #" << i
1458 << " is consumed in the body but is not marked as such";
1459 }
1460 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1461 // Cannot use op.emitWarning() here as it would attempt to verify the op
1462 // before printing, resulting in infinite recursion.
1463 emitWarning(op->getLoc())
1464 << "op argument #" << i
1465 << " is not consumed in the body but is marked as consumed";
1466 }
1467 }
1469}
1470
// Verifies every (matcher, action) symbol pair of this op. Both symbols must
// resolve to function-like transform ops with valid consumed/readonly
// annotations. Arity and transform type interfaces must then agree along the
// chain: op operands -> matcher arguments, matcher results -> action
// arguments, action results -> op results after the first.
// NOTE(review): several lines of this definition (the calls whose results are
// tested with checkAndReport() and the declarations of the `diag` variables)
// are absent from this listing.
1471LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1472 SymbolTableCollection &symbolTable) {
1473 assert(getMatchers().size() == getActions().size());
1474 auto consumedAttr =
1475 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1476 for (auto &&[matcher, action] :
1477 llvm::zip_equal(getMatchers(), getActions())) {
1478 // Presence and typing.
1479 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1480 symbolTable.lookupNearestSymbolFrom(getOperation(),
1481 cast<SymbolRefAttr>(matcher)));
1482 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1483 symbolTable.lookupNearestSymbolFrom(getOperation(),
1484 cast<SymbolRefAttr>(action)));
1485 if (!matcherSymbol ||
1486 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1487 return emitError() << "unresolved matcher symbol " << matcher;
1488 if (!actionSymbol ||
1489 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1490 return emitError() << "unresolved action symbol " << action;
1491
1493 /*emitWarnings=*/false,
1494 /*alsoVerifyInternal=*/true)
1495 .checkAndReport())) {
1496 return failure();
1497 }
1499 /*emitWarnings=*/false,
1500 /*alsoVerifyInternal=*/true)
1501 .checkAndReport())) {
1502 return failure();
1503 }
1504
1505 // Input -> matcher forwarding.
1506 TypeRange operandTypes = getOperandTypes();
1507 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1508 if (operandTypes.size() != matcherArguments.size()) {
1510 emitError() << "the number of operands (" << operandTypes.size()
1511 << ") doesn't match the number of matcher arguments ("
1512 << matcherArguments.size() << ") for " << matcher;
1513 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1514 return diag;
1515 }
1516 for (auto &&[i, operand, argument] :
1517 llvm::enumerate(operandTypes, matcherArguments)) {
// Matchers must not be annotated as consuming any of their arguments.
1518 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1520 emitOpError()
1521 << "does not expect matcher symbol to consume its operand #" << i;
1522 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1523 return diag;
1524 }
1525
1526 if (implementSameTransformInterface(operand, argument))
1527 continue;
1528
1530 emitError()
1531 << "mismatching type interfaces for operand and matcher argument #"
1532 << i << " of matcher " << matcher;
1533 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1534 return diag;
1535 }
1536
1537 // Matcher -> action forwarding.
1538 TypeRange matcherResults = matcherSymbol.getResultTypes();
1539 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1540 if (matcherResults.size() != actionArguments.size()) {
1541 return emitError() << "mismatching number of matcher results and "
1542 "action arguments between "
1543 << matcher << " (" << matcherResults.size() << ") and "
1544 << action << " (" << actionArguments.size() << ")";
1545 }
1546 for (auto &&[i, matcherType, actionType] :
1547 llvm::enumerate(matcherResults, actionArguments)) {
1548 if (implementSameTransformInterface(matcherType, actionType))
1549 continue;
1550
// NOTE(review): "of matcher " below has no leading space, so the message reads
// e.g. "#0of matcher"; the analogous message for operands above uses
// " of matcher ". Left untouched here because it is a runtime string.
1551 return emitError() << "mismatching type interfaces for matcher result "
1552 "and action argument #"
1553 << i << "of matcher " << matcher << " and action "
1554 << action;
1555 }
1556
1557 // Action -> result forwarding.
1558 TypeRange actionResults = actionSymbol.getResultTypes();
// The first op result is skipped; only the remaining results are compared
// with what the action returns.
1559 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1560 if (actionResults.size() != resultTypes.size()) {
1562 emitError() << "the number of action results ("
1563 << actionResults.size() << ") for " << action
1564 << " doesn't match the number of extra op results ("
1565 << resultTypes.size() << ")";
1566 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1567 return diag;
1568 }
1569 for (auto &&[i, resultType, actionType] :
1570 llvm::enumerate(resultTypes, actionResults)) {
1571 if (implementSameTransformInterface(resultType, actionType))
1572 continue;
1573
1575 emitError() << "mismatching type interfaces for action result #" << i
1576 << " of action " << action << " and op result";
1577 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1578 return diag;
1579 }
1580 }
1581 return success();
1582}
1583
1584//===----------------------------------------------------------------------===//
1585// ForeachOp
1586//===----------------------------------------------------------------------===//
1587
// Runs the body once per "zipped" tuple of payload entities: iteration k binds
// each block argument to the k-th payload entity of the corresponding target,
// applies the body ops, and appends what the yield operands map to onto the
// per-result accumulators.
// NOTE(review): several lines of this definition (return type, remaining
// parameters, the declarations of `payloads` and `result`, and some return
// statements) are absent from this listing.
1589transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1592 // We store the payloads before executing the body as ops may be removed from
1593 // the mapping by the TrackingRewriter while iteration is in progress.
1595 detail::prepareValueMappings(payloads, getTargets(), state);
1596 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1597 bool withZipShortest = getWithZipShortest();
1598
1599 // In case of `zip_shortest`, set the number of iterations to the
1600 // smallest payload in the targets.
// NOTE(review): the result of min_element is dereferenced unconditionally; if
// `payloads` can be empty (an op with no targets) that dereferences the end
// iterator - confirm that at least one target is guaranteed.
1601 if (withZipShortest) {
1602 numIterations =
1603 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1604 const SmallVector<MappedValue> &b) {
1605 return a.size() < b.size();
1606 })->size();
1607
// Truncate every payload list to the shortest length.
1608 for (auto &payload : payloads)
1609 payload.resize(numIterations);
1610 }
1611
1612 // As we will be "zipping" over them, check all payloads have the same size.
1613 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1614 // when true.
1615 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1616 argIdx++) {
1617 if (payloads[argIdx].size() != numIterations) {
1618 return emitSilenceableError()
1619 << "prior targets' payload size (" << numIterations
1620 << ") differs from payload size (" << payloads[argIdx].size()
1621 << ") of target " << getTargets()[argIdx];
1622 }
1623 }
1624
1625 // Start iterating, indexing into payloads to obtain the right arguments to
1626 // call the body with - each slice of payloads at the same argument index
1627 // corresponding to a tuple to use as the body's block arguments.
1628 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1629 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1630 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
// A fresh region scope is created for every iteration.
1631 auto scope = state.make_region_scope(getBody());
1632 // Set up arguments to the region's block.
1633 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1634 MappedValue argument = payloads[argIdx][iterIdx];
1635 // Note that each blockArg's handle gets associated with just a single
1636 // element from the corresponding target's payload.
1637 if (failed(state.mapBlockArgument(blockArg, {argument})))
1639 }
1640
1641 // Execute loop body.
// Any non-success result, silenceable or definite, ends the loop immediately.
1642 for (Operation &transform : getBody().front().without_terminator()) {
1644 llvm::cast<transform::TransformOpInterface>(transform));
1645 if (!result.succeeded())
1646 return result;
1647 }
1648
1649 // Append yielded payloads to corresponding results from prior iterations.
1650 OperandRange yieldOperands = getYieldOp().getOperands();
1651 for (auto &&[result, yieldOperand, resTuple] :
1652 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1653 // NB: each iteration we add any number of ops/vals/params to a result.
1654 if (isa<TransformHandleTypeInterface>(result.getType()))
1655 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1656 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1657 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1658 else if (isa<TransformParamTypeInterface>(result.getType()))
1659 llvm::append_range(resTuple, state.getParams(yieldOperand));
1660 else
1661 assert(false && "unhandled handle type");
1662 }
1663
1664 // Associate the accumulated result payloads to the op's actual results.
1665 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1666 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1667
1669}
1670
// Derives the op's effects from its body: a target is consumed if any body op
// consumes the matching block argument and is only read otherwise. The payload
// is marked modified if any body op modifies it, else read if any body op
// reads it. All op results are produced.
// NOTE(review): the parameter line declaring `effects` is absent from this
// listing.
1671void transform::ForeachOp::getEffects(
1673 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1674 // arity errors, this method might get called before/in absence of `verify()`.
1675 for (auto &&[target, blockArg] :
1676 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
// Copied into a named local, which is what the lambda below captures.
1677 BlockArgument blockArgument = blockArg;
1678 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1679 return isHandleConsumed(blockArgument,
1680 cast<TransformOpInterface>(&op));
1681 })) {
1682 consumesHandle(target, effects);
1683 } else {
1684 onlyReadsHandle(target, effects);
1685 }
1686 }
1687
// If no body op modifies or reads the payload, no payload effect is added.
1688 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1689 return doesModifyPayload(cast<TransformOpInterface>(&op));
1690 })) {
1691 modifiesPayload(effects);
1692 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1693 return doesReadPayload(cast<TransformOpInterface>(&op));
1694 })) {
1695 onlyReadsPayload(effects);
1696 }
1697
1698 producesHandle(getOperation()->getOpResults(), effects);
1699}
1700
// Reports control-flow successors: from the parent op the only successor is
// the body region; from the body the successors are the body again and the
// parent.
// NOTE(review): the parameter line declaring `point` and `regions` is absent
// from this listing.
1701void transform::ForeachOp::getSuccessorRegions(
1703 Region *bodyRegion = &getBody();
1704 if (point.isParent()) {
1705 regions.emplace_back(bodyRegion);
1706 return;
1707 }
1708
1709 // Branch back to the region or the parent.
1710 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
1711 &getBody() &&
1712 "unexpected region index");
1713 regions.emplace_back(bodyRegion);
1714 regions.push_back(RegionSuccessor::parent());
1715}
1716
1717ValueRange transform::ForeachOp::getSuccessorInputs(RegionSuccessor successor) {
1718 return successor.isParent() ? ValueRange(getResults())
1719 : ValueRange(getBody().getArguments());
1720}
1721
// Returns all operands of this op as the values forwarded when entering the
// body region; the body is the only successor accepted (asserted below).
// NOTE(review): the return-type line is absent from this listing.
1723transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1724 // Each block argument handle is mapped to a subset (one op to be precise)
1725 // of the payload of the corresponding `targets` operand of ForeachOp.
1726 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1727 return getOperation()->getOperands();
1728}
1729
1730transform::YieldOp transform::ForeachOp::getYieldOp() {
1731 return cast<transform::YieldOp>(getBody().front().getTerminator());
1732}
1733
1734LogicalResult transform::ForeachOp::verify() {
1735 for (auto [targetOpt, bodyArgOpt] :
1736 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1737 if (!targetOpt || !bodyArgOpt)
1738 return emitOpError() << "expects the same number of targets as the body "
1739 "has block arguments";
1740 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1741 return emitOpError(
1742 "expects co-indexed targets and the body's "
1743 "block arguments to have the same op/value/param type");
1744 }
1745
1746 for (auto [resultOpt, yieldOperandOpt] :
1747 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1748 if (!resultOpt || !yieldOperandOpt)
1749 return emitOpError() << "expects the same number of results as the "
1750 "yield terminator has operands";
1751 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1752 return emitOpError("expects co-indexed results and yield "
1753 "operands to have the same op/value/param type");
1754 }
1755
1756 return success();
1757}
1758
1759//===----------------------------------------------------------------------===//
1760// GetParentOp
1761//===----------------------------------------------------------------------===//
1762
// For every target payload op, climbs to the requested ancestor (the n-th
// parent satisfying the optional isolated-from-above and op-name filters) and
// maps the collected ancestors to the result.
// NOTE(review): several lines of this definition (return type, remaining
// parameters, the declarations of `parents` and `diag`, the second operand of
// the isolated-from-above check, and some return statements) are absent from
// this listing.
1764transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
// Parents already recorded; consulted only when deduplication is requested.
1768 DenseSet<Operation *> resultSet;
1769 for (Operation *target : state.getPayloadOps(getTarget())) {
1770 Operation *parent = target;
// Each outer iteration advances to the next ancestor that passes the filters.
1771 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1772 parent = parent->getParentOp();
1773 while (parent) {
1774 bool checkIsolatedFromAbove =
1775 !getIsolatedFromAbove() ||
1777 bool checkOpName = !getOpName().has_value() ||
1778 parent->getName().getStringRef() == *getOpName();
1779 if (checkIsolatedFromAbove && checkOpName)
1780 break;
1781 parent = parent->getParentOp();
1782 }
// Ran out of ancestors before finding a match.
1783 if (!parent) {
// NOTE(review): `parents` may already hold entries from earlier targets here,
// so the result set in this branch is not necessarily empty - confirm intent.
1784 if (getAllowEmptyResults()) {
1785 results.set(llvm::cast<OpResult>(getResult()), parents);
1787 }
1789 emitSilenceableError()
1790 << "could not find a parent op that matches all requirements";
1791 diag.attachNote(target->getLoc()) << "target op";
1792 return diag;
1793 }
1794 }
// With deduplication, a parent is recorded only on its first occurrence.
1795 if (getDeduplicate()) {
1796 if (resultSet.insert(parent).second)
1797 parents.push_back(parent);
1798 } else {
1799 parents.push_back(parent);
1800 }
1801 }
1802 results.set(llvm::cast<OpResult>(getResult()), parents);
1804}
1805
1806//===----------------------------------------------------------------------===//
1807// GetConsumersOfResult
1808//===----------------------------------------------------------------------===//
1809
/// Maps the result handle to the users of result number `getResultNumber()`
/// of the single payload op associated with the target handle.
///   - An empty target handle sets an empty result list.
///   - More than one payload op, or a result number not smaller than the
///     payload op's result count, is reported as a definite failure.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature and the return statements are not
// visible in this view.
1811transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1814 int64_t resultNumber = getResultNumber();
1815 auto payloadOps = state.getPayloadOps(getTarget());
1816 if (std::empty(payloadOps)) {
1817 results.set(cast<OpResult>(getResult()), {});
1819 }
1820 if (!llvm::hasSingleElement(payloadOps))
1821 return emitDefiniteFailure()
1822 << "handle must be mapped to exactly one payload op";
1823
1824 Operation *target = *payloadOps.begin();
1825 if (target->getNumResults() <= resultNumber)
1826 return emitDefiniteFailure() << "result number overflow";
1827 results.set(llvm::cast<OpResult>(getResult()),
1828 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1830}
1831
1832//===----------------------------------------------------------------------===//
1833// GetDefiningOp
1834//===----------------------------------------------------------------------===//
1835
/// Maps the result handle to the defining ops of the payload values associated
/// with the target handle, in iteration order. A payload value that is a block
/// argument has no defining op and produces a silenceable error with a note at
/// the value's location.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `diag` and the
// final return are not visible in this view.
1837transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1840 SmallVector<Operation *> definingOps;
1841 for (Value v : state.getPayloadValues(getTarget())) {
1842 if (llvm::isa<BlockArgument>(v)) {
1844 emitSilenceableError() << "cannot get defining op of block argument";
1845 diag.attachNote(v.getLoc()) << "target value";
1846 return diag;
1847 }
1848 definingOps.push_back(v.getDefiningOp());
1849 }
1850 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1852}
1853
1854//===----------------------------------------------------------------------===//
1855// GetProducerOfOperand
1856//===----------------------------------------------------------------------===//
1857
/// For every payload op of the target handle, looks up the op defining operand
/// number `getOperandNumber()` and collects these producers into the result
/// handle. A silenceable error (with a note on the payload op) is returned
/// when the operand number is out of range or the operand has no defining op.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `diag` and the
// final return are not visible in this view.
1859transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1862 int64_t operandNumber = getOperandNumber();
1863 SmallVector<Operation *> producers;
1864 for (Operation *target : state.getPayloadOps(getTarget())) {
// An out-of-range operand number is folded into the "no producer" case.
1865 Operation *producer =
1866 target->getNumOperands() <= operandNumber
1867 ? nullptr
1868 : target->getOperand(operandNumber).getDefiningOp();
1869 if (!producer) {
1871 emitSilenceableError()
1872 << "could not find a producer for operand number: " << operandNumber
1873 << " of " << *target;
1874 diag.attachNote(target->getLoc()) << "target op";
1875 return diag;
1876 }
1877 producers.push_back(producer);
1878 }
1879 results.set(llvm::cast<OpResult>(getResult()), producers);
1881}
1882
1883//===----------------------------------------------------------------------===//
1884// GetOperandOp
1885//===----------------------------------------------------------------------===//
1886
/// Collects, for every payload op of the target handle, the operands at the
/// positions selected by the op's position specification (raw position list,
/// "all" and "inverted" flags) and maps the result value handle to them.
/// A silenceable failure while resolving positions is returned with a note
/// pointing at the payload op being processed.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the call that produces `diag` and
// fills `operandPositions`, and the final return are not visible in this view.
1888transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1891 SmallVector<Value> operands;
1892 for (Operation *target : state.getPayloadOps(getTarget())) {
1893 SmallVector<int64_t> operandPositions;
1895 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1896 target->getNumOperands(), operandPositions);
1897 if (diag.isSilenceableFailure()) {
1898 diag.attachNote(target->getLoc())
1899 << "while considering positions of this payload operation";
1900 return diag;
1901 }
1902 llvm::append_range(operands,
1903 llvm::map_range(operandPositions, [&](int64_t pos) {
1904 return target->getOperand(pos);
1905 }));
1906 }
1907 results.setValues(cast<OpResult>(getResult()), operands);
1909}
1910
/// Verifies the position specification of the op by delegating to
/// `verifyTransformMatchDimsOp` with the raw position list and the
/// "inverted" / "all" flags.
1911LogicalResult transform::GetOperandOp::verify() {
1912 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1913 getIsInverted(), getIsAll());
1914}
1915
1916//===----------------------------------------------------------------------===//
1917// GetResultOp
1918//===----------------------------------------------------------------------===//
1919
/// Collects, for every payload op of the target handle, the results at the
/// positions selected by the op's position specification (raw position list,
/// "all" and "inverted" flags) and maps the result value handle to them.
/// A silenceable failure while resolving positions is returned with a note
/// pointing at the payload op being processed.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the call that produces `diag` and
// fills `resultPositions`, and the final return are not visible in this view.
1921transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1924 SmallVector<Value> opResults;
1925 for (Operation *target : state.getPayloadOps(getTarget())) {
1926 SmallVector<int64_t> resultPositions;
1928 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1929 target->getNumResults(), resultPositions);
1930 if (diag.isSilenceableFailure()) {
1931 diag.attachNote(target->getLoc())
1932 << "while considering positions of this payload operation";
1933 return diag;
1934 }
1935 llvm::append_range(opResults,
1936 llvm::map_range(resultPositions, [&](int64_t pos) {
1937 return target->getResult(pos);
1938 }));
1939 }
1940 results.setValues(cast<OpResult>(getResult()), opResults);
1942}
1943
/// Verifies the position specification of the op by delegating to
/// `verifyTransformMatchDimsOp` with the raw position list and the
/// "inverted" / "all" flags.
1944LogicalResult transform::GetResultOp::verify() {
1945 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1946 getIsInverted(), getIsAll());
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// GetTypeOp
1951//===----------------------------------------------------------------------===//
1952
/// Declares the op's effects: the value operand handle is only read, the op
/// results are produced handles, and the payload is only read.
// NOTE(review): the parameter line declaring `effects` (listing line 1954) is
// not visible in this view.
1953void transform::GetTypeOp::getEffects(
1955 onlyReadsHandle(getValueMutable(), effects);
1956 producesHandle(getOperation()->getOpResults(), effects);
1957 onlyReadsPayload(effects);
1958}
1959
/// Produces one `TypeAttr` parameter per payload value of the operand handle.
/// When `getElemental()` is set and the value's type is a `ShapedType`, the
/// element type is used instead of the full type; other types are passed
/// through unchanged.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `params` and the
// final return are not visible in this view.
1961transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1965 for (Value value : state.getPayloadValues(getValue())) {
1966 Type type = value.getType();
1967 if (getElemental()) {
1968 if (auto shaped = dyn_cast<ShapedType>(type)) {
1969 type = shaped.getElementType();
1970 }
1971 }
1972 params.push_back(TypeAttr::get(type));
1973 }
1974 results.setParams(cast<OpResult>(getResult()), params);
1976}
1977
1978//===----------------------------------------------------------------------===//
1979// IncludeOp
1980//===----------------------------------------------------------------------===//
1981
1982/// Applies the transform ops contained in `block`. Maps `results` to the same
1983/// values as the operands of the block terminator.
///
/// Failure handling per nested op:
///   - a definite failure is returned immediately;
///   - a silenceable failure is returned in `Propagate` mode, after forwarding
///     empty operands to `results`; in any other mode it is silenced and
///     application continues with the next op.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the return type, the `state` parameter, the declaration of
// `result` and the final return are not visible in this view.
1985applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1987 transform::TransformResults &results) {
1988 // Apply the sequenced ops one by one.
1989 for (Operation &transform : block.without_terminator()) {
1991 state.applyTransform(cast<transform::TransformOpInterface>(transform));
1992 if (result.isDefiniteFailure())
1993 return result;
1994
1995 if (result.isSilenceableFailure()) {
1996 if (mode == transform::FailurePropagationMode::Propagate) {
1997 // Propagate empty results in case of early exit.
1998 forwardEmptyOperands(&block, state, results);
1999 return result;
2000 }
2001 (void)result.silence();
2002 }
2003 }
2004
2005 // Forward the operation mapping for values yielded from the sequence to the
2006 // values produced by the sequence op.
2007 transform::detail::forwardTerminatorOperands(&block, state, results);
2009}
2010
/// Applies the named sequence referenced by this op: resolves the callee,
/// maps this op's operands to the callee's entry block arguments inside a
/// fresh region scope, runs the callee body through `applySequenceBlock` with
/// this op's failure propagation mode, and on success maps this op's results
/// to the values yielded by the callee's terminator.
/// An external (body-less) callee is a definite failure.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the callee lookup call, the
// declarations of `mappings` and `result`, and the statement executed when
// `mapBlockArgument` fails are not visible in this view.
2012transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2016 getOperation(), getTarget());
2017 assert(callee && "unverified reference to unknown symbol");
2018
2019 if (callee.isExternal())
2020 return emitDefiniteFailure() << "unresolved external named sequence";
2021
2022 // Map operands to block arguments.
2024 detail::prepareValueMappings(mappings, getOperands(), state);
2025 auto scope = state.make_region_scope(callee.getBody());
2026 for (auto &&[arg, map] :
2027 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2028 if (failed(state.mapBlockArgument(arg, map)))
2030 }
2031
2033 callee.getBody().front(), getFailurePropagationMode(), state, results);
2034
// Any non-success outcome (definite or silenceable) is returned as is.
2035 if (!result.succeeded())
2036 return result;
2037
// Reuse `mappings` for the values yielded by the callee's terminator.
2038 mappings.clear();
2039 detail::prepareValueMappings(
2040 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2041 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2042 results.setMappedValues(result, mapping);
2043 return result;
2044}
2045
// Forward declaration; the definition appears further down in this file.
// NOTE(review): the line carrying the return type (listing line 2046) is not
// visible in this view.
2047verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2048
/// Declares the op's effects. The payload is always marked as modified and the
/// results as produced handles. Operand effects come from the callee's
/// argument attributes: "consumed" arguments consume the operand handle,
/// "readonly" arguments only read it. When the target attribute or the callee
/// cannot be found, every operand is conservatively marked as only read.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the parameter line declaring `effects` and the start of the
// callee lookup statement are not visible in this view.
2049void transform::IncludeOp::getEffects(
2051 // Always mark as modifying the payload.
2052 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2053 // only read, the payload may still be modified, so we currently stay on the
2054 // conservative side and always indicate modification. This may prevent some
2055 // code reordering.
2056 modifiesPayload(effects);
2057
2058 // Results are always produced.
2059 producesHandle(getOperation()->getOpResults(), effects);
2060
2061 // Adds default effects to operands and results. This will be added if
2062 // preconditions fail so the trait verifier doesn't complain about missing
2063 // effects and the real precondition failure is reported later on.
2064 auto defaultEffects = [&] {
2065 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2066 };
2067
2068 // Bail if the callee is unknown. This may run as part of the verification
2069 // process before we verified the validity of the callee or of this op.
2070 auto target =
2071 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2072 if (!target)
2073 return defaultEffects();
2075 getOperation(), getTarget());
2076 if (!callee)
2077 return defaultEffects();
2078
// Operands whose callee argument carries neither attribute get no effect
// added by this loop.
2079 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2080 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2081 consumesHandle(getOperation()->getOpOperand(i), effects);
2082 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2083 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2084 }
2085}
2086
/// Verifies the symbol reference of this op against the symbol table:
///   - a 'target' symbol reference attribute must be present;
///   - it must resolve to a `transform::NamedSequenceOp`;
///   - operand count and operand types must equal the callee's inputs;
///   - result count must equal the callee's result count, and each result type
///     must implement the same transform dialect interface as the callee's.
/// The final status comes from an additional check on the callee whose
/// diagnostic is reported via `checkAndReport()`.
// NOTE(review): the line that starts the final return statement (listing line
// 2125) is not visible in this view, so the callee of that last check cannot
// be named from here.
2087LogicalResult
2088transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2089 // Access through indirection and do additional checking because this may be
2090 // running before the main op verifier.
2091 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2092 if (!targetAttr)
2093 return emitOpError() << "expects a 'target' symbol reference attribute";
2094
2095 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2096 *this, targetAttr);
2097 if (!target)
2098 return emitOpError() << "does not reference a named transform sequence";
2099
2100 FunctionType fnType = target.getFunctionType();
2101 if (fnType.getNumInputs() != getNumOperands())
2102 return emitError("incorrect number of operands for callee");
2103
// Operand types must be exactly equal to the callee input types.
2104 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2105 if (getOperand(i).getType() != fnType.getInput(i)) {
2106 return emitOpError("operand type mismatch: expected operand type ")
2107 << fnType.getInput(i) << ", but provided "
2108 << getOperand(i).getType() << " for operand number " << i;
2109 }
2110 }
2111
2112 if (fnType.getNumResults() != getNumResults())
2113 return emitError("incorrect number of results for callee");
2114
// Result types only need to implement the same transform interface; they are
// not compared for equality.
2115 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2116 Type resultType = getResult(i).getType();
2117 Type funcType = fnType.getResult(i);
2118 if (!implementSameTransformInterface(resultType, funcType)) {
2119 return emitOpError() << "type of result #" << i
2120 << " must implement the same transform dialect "
2121 "interface as the corresponding callee result";
2122 }
2123 }
2124
2126 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2127 /*alsoVerifyInternal=*/true)
2128 .checkAndReport();
2129}
2130
2131//===----------------------------------------------------------------------===//
2132// MatchOperationEmptyOp
2133//===----------------------------------------------------------------------===//
2134
/// Matcher for an absent operation: the branch taken when `maybeCurrent` holds
/// no value logs a success message; otherwise a failure message is logged and
/// a silenceable error "operation is not empty" is returned.
// NOTE(review): the remaining parameters (listing line 2137) and the statement
// that ends the success branch (listing line 2140) are not visible here.
2135DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2136 ::std::optional<::mlir::Operation *> maybeCurrent,
2138 if (!maybeCurrent.has_value()) {
2139 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2141 }
2142 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2143 return emitSilenceableError() << "operation is not empty";
2144}
2145
2146//===----------------------------------------------------------------------===//
2147// MatchOperationNameOp
2148//===----------------------------------------------------------------------===//
2149
/// Compares the name of `current` against each string in `getOpNames()`.
/// If no accepted name is equal to it, a silenceable error "wrong operation
/// name" is returned.
// NOTE(review): the last parameter line (listing line 2152) and the statement
// executed on a name match (listing line 2156) are not visible here.
2150DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2151 Operation *current, transform::TransformResults &results,
2153 StringRef currentOpName = current->getName().getStringRef();
2154 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2155 if (acceptedAttr.getValue() == currentOpName)
2157 }
2158 return emitSilenceableError() << "wrong operation name";
2159}
2160
2161//===----------------------------------------------------------------------===//
2162// MatchParamCmpIOp
2163//===----------------------------------------------------------------------===//
2164
/// Compares, element by element, the integer parameters associated with the
/// `param` operand against those associated with the `reference` operand,
/// using the op's predicate with signed comparisons.
///   - Different payload lengths: silenceable error.
///   - A non-integer attribute or mismatching integer attribute types:
///     definite failure.
///   - A pair that does not satisfy the predicate: silenceable error that
///     prints both values as signed integers, with a note at the location of
///     the `param` operand.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `diag` inside
// `reportError` and the final return are not visible in this view.
2166transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
// Helper rendering an APInt as a signed decimal string for diagnostics.
2169 auto signedAPIntAsString = [&](const APInt &value) {
2170 std::string str;
2171 llvm::raw_string_ostream os(str);
2172 value.print(os, /*isSigned=*/true);
2173 return str;
2174 };
2175
2176 ArrayRef<Attribute> params = state.getParams(getParam());
2177 ArrayRef<Attribute> references = state.getParams(getReference());
2178
2179 if (params.size() != references.size()) {
2180 return emitSilenceableError()
2181 << "parameters have different payload lengths (" << params.size()
2182 << " vs " << references.size() << ")";
2183 }
2184
2185 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2186 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2187 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2188 if (!intAttr || !refAttr) {
2189 return emitDefiniteFailure()
2190 << "non-integer parameter value not expected";
2191 }
2192 if (intAttr.getType() != refAttr.getType()) {
2193 return emitDefiniteFailure()
2194 << "mismatching integer attribute types in parameter #" << i;
2195 }
2196 APInt value = intAttr.getValue();
2197 APInt refValue = refAttr.getValue();
2198
2199 // TODO: this copy will not be necessary in C++20.
2200 int64_t position = i;
2201 auto reportError = [&](StringRef direction) {
2203 emitSilenceableError() << "expected parameter to be " << direction
2204 << " " << signedAPIntAsString(refValue)
2205 << ", got " << signedAPIntAsString(value);
2206 diag.attachNote(getParam().getLoc())
2207 << "value # " << position
2208 << " associated with the parameter defined here";
2209 return diag;
2210 };
2211
// In each case, `break` means the pair satisfies the predicate and the loop
// moves on to the next pair; otherwise the error is returned right away.
2212 switch (getPredicate()) {
2213 case MatchCmpIPredicate::eq:
2214 if (value.eq(refValue))
2215 break;
2216 return reportError("equal to");
2217 case MatchCmpIPredicate::ne:
2218 if (value.ne(refValue))
2219 break;
2220 return reportError("not equal to");
2221 case MatchCmpIPredicate::lt:
2222 if (value.slt(refValue))
2223 break;
2224 return reportError("less than");
2225 case MatchCmpIPredicate::le:
2226 if (value.sle(refValue))
2227 break;
2228 return reportError("less than or equal to");
2229 case MatchCmpIPredicate::gt:
2230 if (value.sgt(refValue))
2231 break;
2232 return reportError("greater than");
2233 case MatchCmpIPredicate::ge:
2234 if (value.sge(refValue))
2235 break;
2236 return reportError("greater than or equal to");
2237 }
2238 }
2240}
2241
/// Declares the op's effects: both the `param` and the `reference` operands
/// are only read.
// NOTE(review): the parameter line declaring `effects` (listing line 2243) is
// not visible in this view.
2242void transform::MatchParamCmpIOp::getEffects(
2244 onlyReadsHandle(getParamMutable(), effects);
2245 onlyReadsHandle(getReferenceMutable(), effects);
2246}
2247
2248//===----------------------------------------------------------------------===//
2249// ParamConstantOp
2250//===----------------------------------------------------------------------===//
2251
/// Associates the op's result parameter with a single-element list holding
/// the op's `value` attribute.
// NOTE(review): the rest of the signature and the return statement (listing
// lines 2254-2255, 2257) are not visible in this view.
2253transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2256 results.setParams(cast<OpResult>(getParam()), {getValue()});
2258}
2259
2260//===----------------------------------------------------------------------===//
2261// MergeHandlesOp
2262//===----------------------------------------------------------------------===//
2263
/// Concatenates the payload of all operand handles, in operand order, into the
/// single result. The kind of payload (operations, parameters or values) is
/// chosen from the type of the first operand. When `getDeduplicate()` is set,
/// the concatenation is passed through a `SetVector`, which keeps the first
/// occurrence of each element in order.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `attrs` and the
// return statements that end each branch are not visible in this view.
2265transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2268 ValueRange handles = getHandles();
// Case 1: operation handles.
2269 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2270 SmallVector<Operation *> operations;
2271 for (Value operand : handles)
2272 llvm::append_range(operations, state.getPayloadOps(operand));
2273 if (!getDeduplicate()) {
2274 results.set(llvm::cast<OpResult>(getResult()), operations);
2276 }
2277
2278 SetVector<Operation *> uniqued(llvm::from_range, operations);
2279 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2281 }
2282
// Case 2: parameters.
2283 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2285 for (Value attribute : handles)
2286 llvm::append_range(attrs, state.getParams(attribute));
2287 if (!getDeduplicate()) {
2288 results.setParams(cast<OpResult>(getResult()), attrs);
2290 }
2291
2292 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2293 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2295 }
2296
// Case 3: value handles, the only remaining possibility (asserted).
2297 assert(
2298 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2299 "expected value handle type");
2300 SmallVector<Value> payloadValues;
2301 for (Value value : handles)
2302 llvm::append_range(payloadValues, state.getPayloadValues(value));
2303 if (!getDeduplicate()) {
2304 results.setValues(cast<OpResult>(getResult()), payloadValues);
2306 }
2307
2308 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2309 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2311}
2312
/// Returns whether the same handle may be passed more than once as an
/// operand; this is the case exactly when deduplication is enabled.
2313bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2314 // Handles may be the same if deduplicating is enabled.
2315 return getDeduplicate();
2316}
2317
/// Declares the op's effects: operand handles are only read and the op
/// results are produced handles. No payload effect is declared.
// NOTE(review): the parameter line declaring `effects` (listing line 2319) is
// not visible in this view.
2318void transform::MergeHandlesOp::getEffects(
2320 onlyReadsHandle(getHandlesMutable(), effects);
2321 producesHandle(getOperation()->getOpResults(), effects);
2322
2323 // There are no effects on the Payload IR as this is only a handle
2324 // manipulation.
2325}
2326
/// Folds the op to its only operand when deduplication is disabled and there
/// is exactly one operand handle. Returns a null fold result otherwise.
/// `adaptor` is not used.
2327OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2328 if (getDeduplicate() || getHandles().size() != 1)
2329 return {};
2330
2331 // If deduplication is not required and there is only one operand, it can be
2332 // used directly instead of merging.
2333 return getHandles().front();
2334}
2335
2336//===----------------------------------------------------------------------===//
2337// NamedSequenceOp
2338//===----------------------------------------------------------------------===//
2339
/// Applies the named sequence directly: maps the entry block arguments inside
/// a fresh region scope, then applies the body block via `applySequenceBlock`
/// in `Propagate` failure mode. An external (body-less) sequence is a definite
/// failure.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature and the statement executed when the
// block-argument mapping fails (listing line 2354) are not visible here.
2341transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2344 if (isExternal())
2345 return emitDefiniteFailure() << "unresolved external named sequence";
2346
2347 // Map the entry block argument to the list of operations.
2348 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2349 // without attaching the interface / trait since that is tailored to a
2350 // dangling top-level op that does not get "called".
2351 auto scope = state.make_region_scope(getBody());
2352 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2353 state, this->getOperation(), getBody())))
2355
2356 return applySequenceBlock(getBody().front(),
2357 FailurePropagationMode::Propagate, state, results);
2358}
2359
// Effects declaration of the named sequence op.
// NOTE(review): only the first line of this definition is visible; the
// parameter list and body (listing line 2361) are missing from this view.
2360void transform::NamedSequenceOp::getEffects(
2362
/// Custom assembly parser for the named sequence op. The visible arguments
/// show a non-variadic, function-like signature whose type is built with
/// `Builder::getFunctionType` and stored under the op's function type,
/// argument attribute and result attribute names.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the `result` parameter, the name of the parsing helper that is
// called and one lambda parameter line are not visible in this view.
2363ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2366 parser, result, /*allowVariadic=*/false,
2367 getFunctionTypeAttrName(result.name),
2368 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2370 std::string &) { return builder.getFunctionType(inputs, results); },
2371 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2372}
2373
/// Custom assembly printer for the named sequence op. The op is handed over as
/// a non-variadic `FunctionOpInterface`, together with the names of its
/// function type, argument attribute and result attribute attributes.
// NOTE(review): the line naming the printing helper that is called (listing
// line 2375) is not visible in this view.
2374void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2376 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2377 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2378 getResAttrsAttrName());
2379}
2380
2381/// Verifies that a symbol function-like transform dialect operation has the
2382/// signature and the terminator that have conforming types, i.e., types
2383/// implementing the same transform dialect type interface. If `allowExternal`
2384/// is set, allow external symbols (declarations) and don't check the terminator
2385/// as it may not exist.
///
/// Checks, in order: the op is not nested in another transform op; external
/// or body-less ops are accepted only with `allowExternal`; the first body
/// block is non-empty; its last op is a `transform::YieldOp`; the yield has
/// as many operands as the op has results; co-indexed operand and result
/// types are equal (compared with `==`).
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the return type, the declarations of `diag` and the success
// return statements are not visible in this view.
2387verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2388 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2391 << "cannot be defined inside another transform op";
2392 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2393 return diag;
2394 }
2395
2396 if (op.isExternal() || op.getFunctionBody().empty()) {
2397 if (allowExternal)
2399
2400 return emitSilenceableFailure(op) << "cannot be external";
2401 }
2402
2403 if (op.getFunctionBody().front().empty())
2404 return emitSilenceableFailure(op) << "expected a non-empty body block";
2405
2406 Operation *terminator = &op.getFunctionBody().front().back();
2407 if (!isa<transform::YieldOp>(terminator)) {
2409 << "expected '"
2410 << transform::YieldOp::getOperationName()
2411 << "' as terminator";
2412 diag.attachNote(terminator->getLoc()) << "terminator";
2413 return diag;
2414 }
2415
2416 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2417 return emitSilenceableFailure(terminator)
2418 << "expected terminator to have as many operands as the parent op "
2419 "has results";
2420 }
// `zip_equal` is safe here: the operand/result counts were checked above.
2421 for (auto [i, operandType, resultType] : llvm::zip_equal(
2422 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2423 terminator->getOperands().getType(), op.getResultTypes())) {
2424 if (operandType == resultType)
2425 continue;
2426 return emitSilenceableFailure(terminator)
2427 << "the type of the terminator operand #" << i
2428 << " must match the type of the corresponding parent op result ("
2429 << operandType << " vs " << resultType << ")";
2430 }
2431
2433}
2434
2435/// Verification of a NamedSequenceOp. This does not report the error
2436/// immediately, so it can be used to check for op's well-formedness before the
2437/// verifier runs, e.g., during trait verification.
///
/// Checks, in order: the enclosing symbol table op (if any) carries the
/// named-sequence marker attribute; the op is not nested in another transform
/// op; for external or body-less ops only the consume annotations are
/// verified; otherwise the first body block is non-empty, every non-terminator
/// child implements `TransformOpInterface`, the last op is a
/// `transform::YieldOp` whose operand count and types equal the function
/// results, the consume annotations verify, and finally
/// `verifyYieldingSingleBlockOp` succeeds.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the return type and the declarations of the `diag` variables are
// not visible in this view.
2439verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2440 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2441 if (!parent->getAttr(
2442 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2445 << "expects the parent symbol table to have the '"
2446 << transform::TransformDialect::kWithNamedSequenceAttrName
2447 << "' attribute";
2448 diag.attachNote(parent->getLoc()) << "symbol table operation";
2449 return diag;
2450 }
2451 }
2452
2453 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2456 << "cannot be defined inside another transform op";
2457 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2458 return diag;
2459 }
2460
// Declarations have no body to inspect; only their annotations are checked.
2461 if (op.isExternal() || op.getBody().empty())
2462 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2463 emitWarnings);
2464
2465 if (op.getBody().front().empty())
2466 return emitSilenceableFailure(op) << "expected a non-empty body block";
2467
2468 // Check that all operations in the body implement TransformOpInterface
2469 for (Operation &child : op.getBody().front().without_terminator()) {
2470 if (!isa<transform::TransformOpInterface>(child)) {
2473 << "expected children ops to implement TransformOpInterface";
2474 diag.attachNote(child.getLoc()) << "op without interface";
2475 return diag;
2476 }
2477 }
2478
2479 Operation *terminator = &op.getBody().front().back();
2480 if (!isa<transform::YieldOp>(terminator)) {
2482 << "expected '"
2483 << transform::YieldOp::getOperationName()
2484 << "' as terminator";
2485 diag.attachNote(terminator->getLoc()) << "terminator";
2486 return diag;
2487 }
2488
2489 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2490 return emitSilenceableFailure(terminator)
2491 << "expected terminator to have as many operands as the parent op "
2492 "has results";
2493 }
2494 for (auto [i, operandType, resultType] :
2495 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2496 terminator->getOperands().getType(),
2497 op.getFunctionType().getResults())) {
2498 if (operandType == resultType)
2499 continue;
2500 return emitSilenceableFailure(terminator)
2501 << "the type of the terminator operand #" << i
2502 << " must match the type of the corresponding parent op result ("
2503 << operandType << " vs " << resultType << ")";
2504 }
2505
2506 auto funcOp = cast<FunctionOpInterface>(*op);
2508 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2509 if (!diag.succeeded())
2510 return diag;
2511
2512 return verifyYieldingSingleBlockOp(funcOp,
2513 /*allowExternal=*/true);
2514}
2515
/// Op verifier: runs `verifyNamedSequenceOp` with warnings enabled and
/// converts its diagnostic into a `LogicalResult` via `checkAndReport()`.
2516LogicalResult transform::NamedSequenceOp::verify() {
2517 // Actual verification happens in a separate function for reusability.
2518 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2519}
2520
/// Creates the entry block of the last region registered on `state` and
/// populates it through `bodyBuilder`.
///
/// The block receives one argument of type `bbArgType` followed by one
/// argument per entry of `extraBindingTypes`, all located at `state.location`.
/// `bodyBuilder` is invoked with the builder positioned at the start of the
/// new block; a three-parameter callback receives (builder, location, first
/// argument), any other callback additionally receives the remaining block
/// arguments. The builder's insertion point is restored on return by the
/// insertion guard.
2521template <typename FnTy>
2522static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2523 Type bbArgType, TypeRange extraBindingTypes,
2524 FnTy bodyBuilder) {
2525 SmallVector<Type> types;
2526 types.reserve(1 + extraBindingTypes.size());
2527 types.push_back(bbArgType);
2528 llvm::append_range(types, extraBindingTypes);
2529
2530 OpBuilder::InsertionGuard guard(builder);
// The caller must have added a region to `state` before calling this.
2531 Region *region = state.regions.back().get();
2532 Block *bodyBlock =
2533 builder.createBlock(region, region->begin(), types,
2534 SmallVector<Location>(types.size(), state.location));
2535
2536 // Populate body.
2537 builder.setInsertionPointToStart(bodyBlock);
// The callback arity is inspected at compile time to pick the call form.
2538 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2539 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2540 } else {
2541 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2542 bodyBlock->getArguments().drop_front());
2543 }
2544}
2545
/// Builder for a named sequence with a single root argument. Sets the function
/// type attribute to `(rootType) -> resultTypes`, appends the extra `attrs`,
/// adds the body region and fills it through `buildSequenceBody` with no extra
/// bindings.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the declaration of the `attrs` parameter (listing line 2550) and
// the start of the statement using `symName` (listing line 2552) are not
// visible in this view. `argAttrs` is not referenced by any visible line -
// confirm whether it is intentionally unused.
2546void transform::NamedSequenceOp::build(OpBuilder &builder,
2547 OperationState &state, StringRef symName,
2548 Type rootType, TypeRange resultTypes,
2549 SequenceBodyBuilderFn bodyBuilder,
2551 ArrayRef<DictionaryAttr> argAttrs) {
2553 builder.getStringAttr(symName));
2554 state.addAttribute(getFunctionTypeAttrName(state.name),
2555 TypeAttr::get(FunctionType::get(builder.getContext(),
2556 rootType, resultTypes)));
2557 state.attributes.append(attrs.begin(), attrs.end());
2558 state.addRegion();
2559
2560 buildSequenceBody(builder, state, rootType,
2561 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2562}
2563
2564//===----------------------------------------------------------------------===//
2565// NumAssociationsOp
2566//===----------------------------------------------------------------------===//
2567
/// Counts the payload entities (operations, values or parameters, depending on
/// the handle's type interface) associated with the operand handle and stores
/// the count as an i64 integer attribute parameter in the result.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the head of the type-switch expression
// and the final return are not visible in this view.
2569transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2572 size_t numAssociations =
2574 .Case([&](TransformHandleTypeInterface opHandle) {
2575 return llvm::range_size(state.getPayloadOps(getHandle()));
2576 })
2577 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2578 return llvm::range_size(state.getPayloadValues(getHandle()));
2579 })
2580 .Case([&](TransformParamTypeInterface param) {
2581 return llvm::range_size(state.getParams(getHandle()));
2582 })
2583 .DefaultUnreachable("unknown kind of transform dialect type");
2584 results.setParams(cast<OpResult>(getNum()),
2585 rewriter.getI64IntegerAttr(numAssociations));
2587}
2588
/// Verifies that the result parameter type can hold the produced count by
/// asking the type to check a sample payload made of one i64 integer
/// attribute (value 0) and reporting the outcome.
2589LogicalResult transform::NumAssociationsOp::verify() {
2590 // Verify that the result type accepts an i64 attribute as payload.
2591 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2592 return resultType
2593 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2594 .checkAndReport();
2595}
2596
2597//===----------------------------------------------------------------------===//
2598// SelectOp
2599//===----------------------------------------------------------------------===//
2600
/// Filters the payload ops of the target handle, keeping (in order) those
/// whose operation name is equal to `getOpName()`, and maps the result handle
/// to the kept ops.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the declaration of `result` and the
// final return are not visible in this view.
2602transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2606 auto payloadOps = state.getPayloadOps(getTarget());
2607 for (Operation *op : payloadOps) {
2608 if (op->getName().getStringRef() == getOpName())
2609 result.push_back(op);
2610 }
2611 results.set(cast<OpResult>(getResult()), result);
2613}
2614
2615//===----------------------------------------------------------------------===//
2616// SplitHandleOp
2617//===----------------------------------------------------------------------===//
2618
/// Builder: adds `target` as the only operand and declares
/// `numResultHandles` results, each with the same type as `target`.
/// `builder` is not used.
2619void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2620 Value target, int64_t numResultHandles) {
2621 result.addOperands(target);
2622 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2623}
2624
/// Distributes the payload entities (operations, values or parameters) of the
/// operand handle over the op results: the i-th payload goes to the i-th
/// result, and payloads beyond the number of results go to the result selected
/// by `overflow_result`.
///
/// Produces a silenceable error when
///   - there are more payloads than results and no overflow result is set, or
///   - there are fewer payloads than results, "fail_on_payload_too_small" is
///     set, and the handle is not an empty handle allowed through by
///     "pass_through_empty_handle".
/// Results that receive no payload are mapped to an empty list.
// NOTE(review): the embedded listing numbers are not contiguous in this
// function; the rest of the signature, the head of the type-switch expression
// and the final return are not visible in this view.
2626transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2629 int64_t numPayloads =
2631 .Case([&](TransformHandleTypeInterface x) {
2632 return llvm::range_size(state.getPayloadOps(getHandle()));
2633 })
2634 .Case([&](TransformValueHandleTypeInterface x) {
2635 return llvm::range_size(state.getPayloadValues(getHandle()));
2636 })
2637 .Case([&](TransformParamTypeInterface x) {
2638 return llvm::range_size(state.getParams(getHandle()));
2639 })
2640 .DefaultUnreachable("unknown transform dialect type interface");
2641
2642 auto produceNumOpsError = [&]() {
2643 return emitSilenceableError()
2644 << getHandle() << " expected to contain " << this->getNumResults()
2645 << " payloads but it contains " << numPayloads << " payloads";
2646 };
2647
2648 // Fail if there are more payload ops than results and no overflow result was
2649 // specified.
2650 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2651 return produceNumOpsError();
2652
2653 // Fail if there are more results than payload ops. Unless:
2654 // - "fail_on_payload_too_small" is set to "false", or
2655 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2656 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2657 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2658 return produceNumOpsError();
2659
2660 // Distribute payloads.
2661 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
// Pre-size the overflow bucket only when payloads actually overflow. With
// fewer payloads than results (allowed by the checks above), the difference
// below would be negative and turn into a huge unsigned capacity request.
2662 if (getOverflowResult() && numPayloads > getNumResults())
2663 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2664
// Materialize the payload as a vector of `MappedValue`, whichever kind of
// handle the operand is.
2665 auto container = [&]() {
2666 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2667 return llvm::map_to_vector(
2668 state.getPayloadOps(getHandle()),
2669 [](Operation *op) -> MappedValue { return op; });
2670 }
2671 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2672 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2673 [](Value v) -> MappedValue { return v; });
2674 }
2675 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2676 "unsupported kind of transform dialect type");
2677 return llvm::map_to_vector(state.getParams(getHandle()),
2678 [](Attribute a) -> MappedValue { return a; });
2679 }();
2680
// Payload i goes to result i; everything past the last result goes to the
// overflow result, which is known to be set here because of the first check.
2681 for (auto &&en : llvm::enumerate(container)) {
2682 int64_t resultNum = en.index();
2683 if (resultNum >= getNumResults())
2684 resultNum = *getOverflowResult();
2685 resultHandles[resultNum].push_back(en.value());
2686 }
2687
2688 // Set transform op results.
2689 for (auto &&it : llvm::enumerate(resultHandles))
2690 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2691 it.value());
2692
2694}
2695
/// Declares the op's effects: the operand handle is only read and the op
/// results are produced handles. No payload effect is declared.
// NOTE(review): the parameter line declaring `effects` (listing line 2697) is
// not visible in this view.
2696void transform::SplitHandleOp::getEffects(
2698 onlyReadsHandle(getHandleMutable(), effects);
2699 producesHandle(getOperation()->getOpResults(), effects);
2700 // There are no effects on the Payload IR as this is only a handle
2701 // manipulation.
2702}
2703
/// Verifies SplitHandleOp:
///  - `overflow_result`, when present, must be a valid result index, i.e.
///    strictly less than the number of results;
///  - every result type must implement the same transform interface as the
///    type of the operand handle (checked by implementSameTransformInterface).
2704LogicalResult transform::SplitHandleOp::verify() {
2705 if (getOverflowResult().has_value() &&
2706 !(*getOverflowResult() < getNumResults()))
2707 return emitOpError("overflow_result is not a valid result index");
2708
2709 for (Type resultType : getResultTypes()) {
2710 if (implementSameTransformInterface(getHandle().getType(), resultType))
2711 continue;
2712
  // The first mismatching result type aborts verification with an error.
2713 return emitOpError("expects result types to implement the same transform "
2714 "interface as the operand type");
2715 }
2716
2717 return success();
2718}
2719
2720//===----------------------------------------------------------------------===//
2721// PayloadOp
2722//===----------------------------------------------------------------------===//
2723
/// Appends the elements of this op's `getNormalForms()` attribute, viewed as
/// NormalFormAttrInterface, to `normalForms`.
2724void transform::PayloadOp::getCheckedNormalForms(
2726 llvm::append_range(normalForms,
2727 getNormalForms().getAsRange<NormalFormAttrInterface>());
2728}
2729
2730//===----------------------------------------------------------------------===//
2731// ReplicateOp
2732//===----------------------------------------------------------------------===//
2733
/// Replicates the payload associated with each operand in `getHandles()`.
/// The repetition count is the number of payload ops mapped to the `pattern`
/// operand. For every handle, the currently associated list (payload ops for
/// TransformHandleTypeInterface types, attributes otherwise) is concatenated
/// with itself that many times and assigned to the result at the same
/// position in `getReplicated()`.
2735transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2738 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2739 for (const auto &en : llvm::enumerate(getHandles())) {
2740 Value handle = en.value();
2741 if (isa<TransformHandleTypeInterface>(handle.getType())) {
  // Operation handle: repeat the list of payload ops.
2742 SmallVector<Operation *> current =
2743 llvm::to_vector(state.getPayloadOps(handle));
2745 payload.reserve(numRepetitions * current.size());
2746 for (unsigned i = 0; i < numRepetitions; ++i)
2747 llvm::append_range(payload, current);
2748 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2749 } else {
  // Anything that is not an operation handle is asserted to be a parameter;
  // repeat the list of attributes.
2750 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2751 "expected param type");
2752 ArrayRef<Attribute> current = state.getParams(handle);
2754 params.reserve(numRepetitions * current.size());
2755 for (unsigned i = 0; i < numRepetitions; ++i)
2756 llvm::append_range(params, current);
2757 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2758 params);
2759 }
2760 }
2762}
2763
/// Registers the effects of ReplicateOp: calls onlyReadsHandle on the
/// `pattern` operand and on the `handles` operands, and producesHandle on all
/// results of the op.
2764void transform::ReplicateOp::getEffects(
2766 onlyReadsHandle(getPatternMutable(), effects);
2767 onlyReadsHandle(getHandlesMutable(), effects);
2768 producesHandle(getOperation()->getOpResults(), effects);
2769}
2770
2771//===----------------------------------------------------------------------===//
2772// SequenceOp
2773//===----------------------------------------------------------------------===//
2774
/// Applies the sequence: opens a region scope for the region that contains the
/// body block, maps the block arguments through mapBlockArguments, and then
/// runs applySequenceBlock on the body block using this op's failure
/// propagation mode.
2776transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2779 // Map the entry block argument to the list of operations.
2780 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2781 if (failed(mapBlockArguments(state)))
2783
2784 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2785 results);
2786}
2787
/// Parses the optional operand list of a sequence op together with its types.
/// Accepted forms, as implemented below:
///   - nothing at all: `root` is set to std::nullopt and parsing succeeds;
///   - a root operand, optionally followed by `,` and a list of extra
///     bindings, then `:`, an optional `(`, the root type, and, when extra
///     bindings were parsed, `,` plus a type list, then an optional `)`.
/// Fails if the number of extra binding types differs from the number of extra
/// bindings.
/// NOTE(review): the opening and closing parens are parsed independently and
/// their results are discarded, so an unbalanced form appears to be accepted.
2788static ParseResult parseSequenceOpOperands(
2789 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2790 Type &rootType,
2792 SmallVectorImpl<Type> &extraBindingTypes) {
2794 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
  // No operand present at all: the op has no root.
2795 if (!hasRoot.has_value()) {
2796 root = std::nullopt;
2797 return success();
2798 }
  // An operand was present but malformed.
2799 if (failed(hasRoot.value()))
2800 return failure();
2801 root = rootOperand;
2802
2803 if (succeeded(parser.parseOptionalComma())) {
2804 if (failed(parser.parseOperandList(extraBindings)))
2805 return failure();
2806 }
2807 if (failed(parser.parseColon()))
2808 return failure();
2809
2810 // The paren is truly optional.
2811 (void)parser.parseOptionalLParen();
2812
2813 if (failed(parser.parseType(rootType))) {
2814 return failure();
2815 }
2816
2817 if (!extraBindings.empty()) {
2818 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2819 return failure();
2820 }
2821
2822 if (extraBindingTypes.size() != extraBindings.size()) {
2823 return parser.emitError(parser.getNameLoc(),
2824 "expected types to be provided for all operands");
2825 }
2826
2827 // The paren is truly optional.
2828 (void)parser.parseOptionalRParen();
2829 return success();
2830}
2831
2833 Value root, Type rootType,
2834 ValueRange extraBindings,
2835 TypeRange extraBindingTypes) {
  // Prints the operand list of a sequence op and its types. Nothing is printed
  // when there is no root. Otherwise the output is `root : rootType`, or, when
  // extra bindings exist, `root, extras... : (rootType, extraTypes...)`.
2836 if (!root)
2837 return;
2838
2839 printer << root;
2840 bool hasExtras = !extraBindings.empty();
2841 if (hasExtras) {
2842 printer << ", ";
2843 printer.printOperands(extraBindings);
2844 }
2845
2846 printer << " : ";
  // Parentheses are only emitted around the type list when there are extras.
2847 if (hasExtras)
2848 printer << "(";
2849
2850 printer << rootType;
2851 if (hasExtras)
2852 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2853}
2854
2855/// Returns `true` if the given op operand may be consuming the handle value in
2856/// the Transform IR. That is, if it may have a Free effect on it.
/// If the owner of the use does not implement TransformOpInterface, the use is
/// treated as a potential consumer; otherwise the answer comes from
/// isHandleConsumed for the used value and the owning op.
2858 // Conservatively assume the effect being present in absence of the interface.
2859 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2860 if (!iface)
2861 return true;
2862
2863 return isHandleConsumed(use.get(), iface);
2864}
2865
/// Checks that `value` has at most one potential consumer among its uses.
/// The first use that is not skipped is remembered; if a second such use is
/// found, the diagnostic produced by `reportError` is extended with
/// " has more than one potential consumer" and one note per offending use
/// (location and operand number), and that diagnostic is returned. Returns
/// success when no second such use exists.
/// NOTE(review): the condition guarding the `continue` below is not visible in
/// this listing - presumably it filters out uses that are not potential
/// consumers; confirm against the full source.
2866static LogicalResult
2868 function_ref<InFlightDiagnostic()> reportError) {
2869 OpOperand *potentialConsumer = nullptr;
2870 for (OpOperand &use : value.getUses()) {
2872 continue;
2873
  // First candidate: remember it and keep scanning.
2874 if (!potentialConsumer) {
2875 potentialConsumer = &use;
2876 continue;
2877 }
2878
  // Second candidate: report both uses.
2879 InFlightDiagnostic diag = reportError()
2880 << " has more than one potential consumer";
2881 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2882 << "used here as operand #" << potentialConsumer->getOperandNumber();
2883 diag.attachNote(use.getOwner()->getLoc())
2884 << "used here as operand #" << use.getOperandNumber();
2885 return diag;
2886 }
2887
2888 return success();
2889}
2890
/// Verifies SequenceOp. The checks performed, in order:
///  - extra bindings are rejected when the op has no root operand;
///  - no body block argument may have more than one potential consumer
///    (checkDoubleConsume);
///  - every op in the body except the last one must implement
///    TransformOpInterface, and no result of any body op may have more than
///    one potential consumer;
///  - the body must have a terminator whose operand types match the result
///    types of this op.
/// NOTE(review): the assertion below checks `>= 1` while its message says
/// "more than 1"; the message wording looks imprecise.
2891LogicalResult transform::SequenceOp::verify() {
2892 assert(getBodyBlock()->getNumArguments() >= 1 &&
2893 "the number of arguments must have been verified to be more than 1 by "
2894 "PossibleTopLevelTransformOpTrait");
2895
2896 if (!getRoot() && !getExtraBindings().empty()) {
2897 return emitOpError()
2898 << "does not expect extra operands when used as top-level";
2899 }
2900
2901 // Check if a block argument has more than one consuming use.
2902 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2903 if (failed(checkDoubleConsume(arg, [this, arg]() {
2904 return (emitOpError() << "block argument #" << arg.getArgNumber());
2905 }))) {
2906 return failure();
2907 }
2908 }
2909
2910 // Check properties of the nested operations they cannot check themselves.
2911 for (Operation &child : *getBodyBlock()) {
  // The last op of the block is exempt from the interface requirement.
2912 if (!isa<TransformOpInterface>(child) &&
2913 &child != &getBodyBlock()->back()) {
2915 emitOpError()
2916 << "expected children ops to implement TransformOpInterface";
2917 diag.attachNote(child.getLoc()) << "op without interface";
2918 return diag;
2919 }
2920
2921 for (OpResult result : child.getResults()) {
2922 auto report = [&]() {
2923 return (child.emitError() << "result #" << result.getResultNumber());
2924 };
2925 if (failed(checkDoubleConsume(result, report)))
2926 return failure();
2927 }
2928 }
2929
2930 if (!getBodyBlock()->mightHaveTerminator())
2931 return emitOpError() << "expects to have a terminator in the body";
2932
2933 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2934 getOperation()->getResultTypes()) {
2936 << "expects the types of the terminator operands "
2937 "to match the types of the result";
2938 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2939 return diag;
2940 }
2941 return success();
2942}
2943
2944void transform::SequenceOp::getEffects(
2947}
2948
/// Returns the operands forwarded to the body region on entry: all operands
/// of the op when it has any, otherwise an empty range. Asserts that the
/// requested successor is the body region.
2950transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2951 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2952 if (getOperation()->getNumOperands() > 0)
2953 return getOperation()->getOperands();
  // Empty range built from two end iterators.
2954 return OperandRange(getOperation()->operand_end(),
2955 getOperation()->operand_end());
2956}
2957
/// Populates `regions` with the possible successors of `point`: when branching
/// from the parent op the only successor is the body region; otherwise the
/// branch point is asserted to be a terminator inside the body region and the
/// only successor is the parent op.
2958void transform::SequenceOp::getSuccessorRegions(
2960 if (point.isParent()) {
2961 Region *bodyRegion = &getBody();
2962 regions.emplace_back(bodyRegion);
2963 return;
2964 }
2965
2966 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2967 &getBody() &&
2968 "unexpected region index");
2969 regions.push_back(RegionSuccessor::parent());
2970}
2971
/// Returns the values that receive the forwarded operands for `successor`:
/// an empty range when the op has no operands, the op results when the
/// successor is the parent op, and the body region arguments otherwise.
2973transform::SequenceOp::getSuccessorInputs(RegionSuccessor successor) {
2974 if (getNumOperands() == 0)
2975 return ValueRange();
2976 if (successor.isParent())
2977 return getResults();
2978 return getBody().getArguments();
2979}
2980
/// Appends a single invocation bound constructed from (1, 1) to `bounds`;
/// `operands` is intentionally unused.
2981void transform::SequenceOp::getRegionInvocationBounds(
2983 (void)operands;
2984 bounds.emplace_back(1, 1);
2985}
2986
/// Builds a sequence op with a root operand and no extra bindings. Delegates
/// to another `build` overload with an empty extra-bindings range, then
/// creates the body through buildSequenceBody, using the type of `root` as the
/// block argument type and `bodyBuilder` as the body callback.
/// `root.getType()` is called unconditionally, so `root` is presumably
/// expected to be non-null.
2987void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2988 TypeRange resultTypes,
2989 FailurePropagationMode failurePropagationMode,
2990 Value root,
2991 SequenceBodyBuilderFn bodyBuilder) {
2992 build(builder, state, resultTypes, failurePropagationMode, root,
2993 /*extra_bindings=*/ValueRange());
2994 Type bbArgType = root.getType();
2995 buildSequenceBody(builder, state, bbArgType,
2996 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2997}
2998
/// Builds a sequence op with a root operand and extra bindings. Delegates to
/// another `build` overload for the operands, then creates the body through
/// buildSequenceBody with the type of `root` and the types of `extraBindings`.
/// `root.getType()` is called unconditionally, so `root` is presumably
/// expected to be non-null.
2999void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3000 TypeRange resultTypes,
3001 FailurePropagationMode failurePropagationMode,
3002 Value root, ValueRange extraBindings,
3003 SequenceBodyBuilderArgsFn bodyBuilder) {
3004 build(builder, state, resultTypes, failurePropagationMode, root,
3005 extraBindings);
3006 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3007 bodyBuilder);
3008}
3009
/// Builds a sequence op without operands: a null root and an empty
/// extra-bindings range are passed to the delegated `build` overload. The body
/// is created through buildSequenceBody with `bbArgType` and no extra binding
/// types.
3010void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3011 TypeRange resultTypes,
3012 FailurePropagationMode failurePropagationMode,
3013 Type bbArgType,
3014 SequenceBodyBuilderFn bodyBuilder) {
3015 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3016 /*extra_bindings=*/ValueRange());
3017 buildSequenceBody(builder, state, bbArgType,
3018 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3019}
3020
/// Builds a sequence op without operands: a null root and an empty
/// extra-bindings range are passed to the delegated `build` overload. The body
/// is created through buildSequenceBody with `bbArgType` and
/// `extraBindingTypes`.
3021void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3022 TypeRange resultTypes,
3023 FailurePropagationMode failurePropagationMode,
3024 Type bbArgType, TypeRange extraBindingTypes,
3025 SequenceBodyBuilderArgsFn bodyBuilder) {
3026 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3027 /*extra_bindings=*/ValueRange());
3028 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3029}
3030
3031//===----------------------------------------------------------------------===//
3032// PrintOp
3033//===----------------------------------------------------------------------===//
3034
/// Builds a print op without a target operand. When `name` is non-empty it is
/// stored as a string attribute in the `name` property; an empty `name` leaves
/// the state untouched.
3035void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3036 StringRef name) {
3037 if (!name.empty())
3038 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3039}
3040
/// Builds a print op for `target`: adds `target` as the operand and then
/// delegates to the name-only overload to set the optional name.
3041void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3042 Value target, StringRef name) {
3043 result.addOperands({target});
3044 build(builder, result, name);
3045}
3046
/// Prints IR to llvm::outs(). The output starts with the banner
/// "[[[ IR printer: " followed by the optional name. Without a target operand
/// the top-level op of the transform state is printed after "top-level ]]]";
/// with a target, every payload op associated with it is printed after "]]]",
/// each followed by a newline. Printing flags are derived from the
/// assume-verified, use-local-scope and skip-regions accessors, each
/// defaulting to false when unset.
3048transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3051 llvm::outs() << "[[[ IR printer: ";
3052 if (getName().has_value())
3053 llvm::outs() << *getName() << " ";
3054
3055 OpPrintingFlags printFlags;
3056 if (getAssumeVerified().value_or(false))
3057 printFlags.assumeVerified();
3058 if (getUseLocalScope().value_or(false))
3059 printFlags.useLocalScope();
3060 if (getSkipRegions().value_or(false))
3061 printFlags.skipRegions();
3062
  // No target operand: print the top-level op known to the transform state.
3063 if (!getTarget()) {
3064 llvm::outs() << "top-level ]]]\n";
3065 state.getTopLevel()->print(llvm::outs(), printFlags);
3066 llvm::outs() << "\n";
3067 llvm::outs().flush();
3069 }
3070
3071 llvm::outs() << "]]]\n";
3072 for (Operation *target : state.getPayloadOps(getTarget())) {
3073 target->print(llvm::outs(), printFlags);
3074 llvm::outs() << "\n";
3075 }
3076
3077 llvm::outs().flush();
3079}
3080
/// Registers the effects of PrintOp: the target operand, when present, is
/// only read; onlyReadsPayload is called for the payload; and a write to the
/// default resource is added to account for the printing.
3081void transform::PrintOp::getEffects(
3083 // We don't really care about mutability here, but `getTarget` now
3084 // unconditionally casts to a specific type before verification could run
3085 // here.
3086 if (!getTargetMutable().empty())
3087 onlyReadsHandle(getTargetMutable()[0], effects);
3088 onlyReadsPayload(effects);
3089
3090 // There is no resource for the output stream (PrintOp::apply prints to
3091 // llvm::outs()), so just declare print writes into the default resource.
3092 effects.emplace_back(MemoryEffects::Write::get());
3093}
3094
3095//===----------------------------------------------------------------------===//
3096// VerifyOp
3097//===----------------------------------------------------------------------===//
3098
/// Applies VerifyOp to a single payload op. Only the failure path is visible
/// in this listing: it reports "failed to verify payload op" and attaches a
/// note at the location of `target`.
3100transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3106 << "failed to verify payload op";
3107 diag.attachNote(target->getLoc()) << "payload op";
3108 return diag;
3109 }
3111}
3112
/// Registers the effects of VerifyOp: calls transform::onlyReadsHandle on the
/// target operand. No other effect is declared here.
3113void transform::VerifyOp::getEffects(
3115 transform::onlyReadsHandle(getTargetMutable(), effects);
3116}
3117
3118//===----------------------------------------------------------------------===//
3119// YieldOp
3120//===----------------------------------------------------------------------===//
3121
/// Registers the effects of YieldOp: calls onlyReadsHandle on all operands.
3122void transform::YieldOp::getEffects(
3124 onlyReadsHandle(getOperandsMutable(), effects);
3125}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload target op.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:143
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
GreedyRewriteConfig & setListener(RewriterBase::Listener *listener)
GreedyRewriteConfig & enableCSEBetweenIterations(bool enable=true)
GreedyRewriteConfig & setMaxIterations(int64_t iterations)
GreedyRewriteConfig & setMaxNumRewrites(int64_t limit)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1143
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
unsigned getNumOperands()
Definition Operation.h:372
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getOpResults()
Definition Operation.h:446
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation, co-indexed with the results of the transform op.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range of payload IR values.
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list of parameters.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range of mapped values.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value corresponds to.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payload IR.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's terminator.
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given state.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operation.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in the block.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr, int64_t *numCSE=nullptr, int64_t *numDCE=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:418
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool eliminateTriviallyDeadOps(RewriterBase &rewriter, Region &region, bool includeNestedRegions=true)
Remove trivially dead operations from region.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition Verifier.cpp:480
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.