MLIR  19.0.0git
DecomposeCallGraphTypes.cpp
Go to the documentation of this file.
1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinOps.h"
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 //===----------------------------------------------------------------------===//
17 // ValueDecomposer
18 //===----------------------------------------------------------------------===//
19 
21  Type type, Value value,
22  SmallVectorImpl<Value> &results) {
23  for (auto &conversion : decomposeValueConversions)
24  if (conversion(builder, loc, type, value, results))
25  return;
26  results.push_back(value);
27 }
28 
29 //===----------------------------------------------------------------------===//
30 // DecomposeCallGraphTypesOpConversionPattern
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// Base OpConversionPattern class to make a ValueDecomposer available to
35 /// inherited patterns.
36 template <typename SourceOp>
37 class DecomposeCallGraphTypesOpConversionPattern
38  : public OpConversionPattern<SourceOp> {
39 public:
40  DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
41  MLIRContext *context,
42  ValueDecomposer &decomposer,
43  PatternBenefit benefit = 1)
44  : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45  decomposer(decomposer) {}
46 
47 protected:
48  ValueDecomposer &decomposer;
49 };
50 } // namespace
51 
52 //===----------------------------------------------------------------------===//
53 // DecomposeCallGraphTypesForFuncArgs
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 /// Expand function arguments according to the provided TypeConverter and
58 /// ValueDecomposer.
59 struct DecomposeCallGraphTypesForFuncArgs
60  : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61  using DecomposeCallGraphTypesOpConversionPattern::
62  DecomposeCallGraphTypesOpConversionPattern;
63 
65  matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
66  ConversionPatternRewriter &rewriter) const final {
67  auto functionType = op.getFunctionType();
68 
69  // Convert function arguments using the provided TypeConverter.
70  TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
71  for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
72  SmallVector<Type, 2> decomposedTypes;
73  if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
74  return failure();
75  if (!decomposedTypes.empty())
76  conversion.addInputs(argType.index(), decomposedTypes);
77  }
78 
79  // If the SignatureConversion doesn't apply, bail out.
80  if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
81  &conversion)))
82  return failure();
83 
84  // Update the signature of the function.
85  SmallVector<Type, 2> newResultTypes;
86  if (failed(typeConverter->convertTypes(functionType.getResults(),
87  newResultTypes)))
88  return failure();
89  rewriter.modifyOpInPlace(op, [&] {
90  op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
91  newResultTypes));
92  });
93  return success();
94  }
95 };
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // DecomposeCallGraphTypesForReturnOp
100 //===----------------------------------------------------------------------===//
101 
102 namespace {
103 /// Expand return operands according to the provided TypeConverter and
104 /// ValueDecomposer.
105 struct DecomposeCallGraphTypesForReturnOp
106  : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107  using DecomposeCallGraphTypesOpConversionPattern::
108  DecomposeCallGraphTypesOpConversionPattern;
110  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
111  ConversionPatternRewriter &rewriter) const final {
112  SmallVector<Value, 2> newOperands;
113  for (Value operand : adaptor.getOperands())
114  decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115  operand, newOperands);
116  rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
117  return success();
118  }
119 };
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // DecomposeCallGraphTypesForCallOp
124 //===----------------------------------------------------------------------===//
125 
126 namespace {
127 /// Expand call op operands and results according to the provided TypeConverter
128 /// and ValueDecomposer.
129 struct DecomposeCallGraphTypesForCallOp
130  : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131  using DecomposeCallGraphTypesOpConversionPattern::
132  DecomposeCallGraphTypesOpConversionPattern;
133 
135  matchAndRewrite(CallOp op, OpAdaptor adaptor,
136  ConversionPatternRewriter &rewriter) const final {
137 
138  // Create the operands list of the new `CallOp`.
139  SmallVector<Value, 2> newOperands;
140  for (Value operand : adaptor.getOperands())
141  decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142  operand, newOperands);
143 
144  // Create the new result types for the new `CallOp` and track the indices in
145  // the new call op's results that correspond to the old call op's results.
146  //
147  // expandedResultIndices[i] = "list of new result indices that old result i
148  // expanded to".
149  SmallVector<Type, 2> newResultTypes;
150  SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
151  for (Type resultType : op.getResultTypes()) {
152  unsigned oldSize = newResultTypes.size();
153  if (failed(typeConverter->convertType(resultType, newResultTypes)))
154  return failure();
155  auto &resultMapping = expandedResultIndices.emplace_back();
156  for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157  resultMapping.push_back(i);
158  }
159 
160  CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161  newResultTypes, newOperands);
162 
163  // Build a replacement value for each result to replace its uses. If a
164  // result has multiple mapping values, it needs to be materialized as a
165  // single value.
166  SmallVector<Value, 2> replacedValues;
167  replacedValues.reserve(op.getNumResults());
168  for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169  auto decomposedValues = llvm::to_vector<6>(
170  llvm::map_range(expandedResultIndices[i],
171  [&](unsigned i) { return newCallOp.getResult(i); }));
172  if (decomposedValues.empty()) {
173  // No replacement is required.
174  replacedValues.push_back(nullptr);
175  } else if (decomposedValues.size() == 1) {
176  replacedValues.push_back(decomposedValues.front());
177  } else {
178  // Materialize a single Value to replace the original Value.
179  Value materialized = getTypeConverter()->materializeArgumentConversion(
180  rewriter, op.getLoc(), op.getType(i), decomposedValues);
181  replacedValues.push_back(materialized);
182  }
183  }
184  rewriter.replaceOp(op, replacedValues);
185  return success();
186  }
187 };
188 } // namespace
189 
191  MLIRContext *context, TypeConverter &typeConverter,
192  ValueDecomposer &decomposer, RewritePatternSet &patterns) {
193  patterns
194  .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195  DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196  decomposer);
197 }
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class provides all of the information necessary to convert a type signature.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides a hook that expands one Value into multiple Values, with a TypeConverter-inspire...
void decomposeValue(OpBuilder &, Location, Type, Value, SmallVectorImpl< Value > &)
This method tries to decompose a value of a certain type using provided decompose callback functions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context, TypeConverter &typeConverter, ValueDecomposer &decomposer, RewritePatternSet &patterns)
Populates the patterns needed to drive the conversion process for decomposing call graph types with t...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26