MLIR 23.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements structural type conversion patterns (and legality
10// rules) for the CF dialect branch ops cf.br, cf.cond_br and cf.switch.
11//
12//===----------------------------------------------------------------------===//
13
15
18#include "mlir/Pass/Pass.h"
20
21using namespace mlir;
22
23namespace {
24
25/// Helper function for converting branch ops. This function converts the
26/// signature of the given block. If the new block signature is different from
27/// `expectedTypes`, returns "failure".
28static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
29 const TypeConverter *converter,
30 Operation *branchOp, Block *block,
31 TypeRange expectedTypes) {
32 assert(converter && "expected non-null type converter");
33 assert(!block->isEntryBlock() && "entry blocks have no predecessors");
34
35 // There is nothing to do if the types already match.
36 if (block->getArgumentTypes() == expectedTypes)
37 return block;
38
39 // Compute the new block argument types and convert the block.
40 std::optional<TypeConverter::SignatureConversion> conversion =
41 converter->convertBlockSignature(block);
42 if (!conversion)
43 return rewriter.notifyMatchFailure(branchOp,
44 "could not compute block signature");
45 if (expectedTypes != conversion->getConvertedTypes())
46 return rewriter.notifyMatchFailure(
47 branchOp,
48 "mismatch between adaptor operand types and computed block signature");
49 return rewriter.applySignatureConversion(block, *conversion, converter);
50}
51
52/// Flatten the given value ranges into a single vector of values.
55 for (const ValueRange &vals : values)
56 llvm::append_range(result, vals);
57 return result;
58}
59
60/// Convert the destination block signature (if necessary) and change the
61/// operands of the branch op.
62struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
63 using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
64
65 LogicalResult
66 matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
67 ConversionPatternRewriter &rewriter) const override {
68 SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
69 FailureOr<Block *> convertedBlock =
70 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
71 TypeRange(ValueRange(flattenedAdaptor)));
72 if (failed(convertedBlock))
73 return failure();
74 rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
75 *convertedBlock);
76 return success();
77 }
78};
79
80/// Convert the destination block signatures (if necessary) and change the
81/// operands of the branch op.
82struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
83 using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
84
85 LogicalResult
86 matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter) const override {
88 SmallVector<Value> flattenedAdaptorTrue =
89 flattenValues(adaptor.getTrueDestOperands());
90 SmallVector<Value> flattenedAdaptorFalse =
91 flattenValues(adaptor.getFalseDestOperands());
92 if (!llvm::hasSingleElement(adaptor.getCondition()))
93 return rewriter.notifyMatchFailure(op,
94 "expected single element condition");
95 FailureOr<Block *> convertedTrueBlock =
96 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
97 TypeRange(ValueRange(flattenedAdaptorTrue)));
98 if (failed(convertedTrueBlock))
99 return failure();
100 FailureOr<Block *> convertedFalseBlock =
101 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
102 TypeRange(ValueRange(flattenedAdaptorFalse)));
103 if (failed(convertedFalseBlock))
104 return failure();
105 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
106 op, llvm::getSingleElement(adaptor.getCondition()),
107 flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
108 *convertedTrueBlock, *convertedFalseBlock);
109 return success();
110 }
111};
112
113/// Convert the destination block signatures (if necessary) and change the
114/// operands of the switch op.
115struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
116 using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
117
118 LogicalResult
119 matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter) const override {
121 // Get or convert default block.
122 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
123 rewriter, getTypeConverter(), op, op.getDefaultDestination(),
124 TypeRange(adaptor.getDefaultOperands()));
125 if (failed(convertedDefaultBlock))
126 return failure();
127
128 // Get or convert all case blocks.
129 SmallVector<Block *> caseDestinations;
130 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
131 for (auto it : llvm::enumerate(op.getCaseDestinations())) {
132 Block *b = it.value();
133 FailureOr<Block *> convertedBlock =
134 getConvertedBlock(rewriter, getTypeConverter(), op, b,
135 TypeRange(caseOperands[it.index()]));
136 if (failed(convertedBlock))
137 return failure();
138 caseDestinations.push_back(*convertedBlock);
139 }
140
141 rewriter.replaceOpWithNewOp<cf::SwitchOp>(
142 op, adaptor.getFlag(), *convertedDefaultBlock,
143 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
144 caseDestinations, caseOperands);
145 return success();
146 }
147};
148
149} // namespace
150
152 const TypeConverter &typeConverter, RewritePatternSet &patterns,
153 PatternBenefit benefit) {
154 patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
155 typeConverter, patterns.getContext(), benefit);
156}
157
159 const TypeConverter &typeConverter, ConversionTarget &target) {
160 target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
161 [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
162}
163
return success()
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
void populateCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of CF operations based on the provided type conver...
void populateCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateCFStructuralTypeConversionsAndLegality but does not populate the conversion target...
void populateCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for CF structural type conversions and sets up the provided ConversionTarget with ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.