MLIR  20.0.0git
DecomposeCallGraphTypes.cpp
Go to the documentation of this file.
1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinOps.h"
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 //===----------------------------------------------------------------------===//
17 // Helper functions
18 //===----------------------------------------------------------------------===//
19 
20 /// If the given value can be decomposed with the type converter, decompose it.
21 /// Otherwise, return the given value.
22 // TODO: Value decomposition should happen automatically through a 1:N adaptor.
23 // This function will disappear when the 1:1 and 1:N drivers are merged.
25  Value value,
26  const TypeConverter *converter) {
27  // Try to convert the given value's type. If that fails, just return the
28  // given value.
29  SmallVector<Type> convertedTypes;
30  if (failed(converter->convertType(value.getType(), convertedTypes)))
31  return {value};
32  if (convertedTypes.empty())
33  return {};
34 
35  // If the given value's type is already legal, just return the given value.
36  TypeRange convertedTypeRange(convertedTypes);
37  if (convertedTypeRange == TypeRange(value.getType()))
38  return {value};
39 
40  // Try to materialize a target conversion. If the materialization did not
41  // produce values of the requested type, the materialization failed. Just
42  // return the given value in that case.
44  builder, loc, convertedTypeRange, value);
45  if (result.empty())
46  return {value};
47  return result;
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // DecomposeCallGraphTypesForFuncArgs
52 //===----------------------------------------------------------------------===//
53 
54 namespace {
55 /// Expand function arguments according to the provided TypeConverter.
56 struct DecomposeCallGraphTypesForFuncArgs
57  : public OpConversionPattern<func::FuncOp> {
59 
60  LogicalResult
61  matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
62  ConversionPatternRewriter &rewriter) const final {
63  auto functionType = op.getFunctionType();
64 
65  // Convert function arguments using the provided TypeConverter.
66  TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
67  for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
68  SmallVector<Type, 2> decomposedTypes;
69  if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
70  return failure();
71  if (!decomposedTypes.empty())
72  conversion.addInputs(argType.index(), decomposedTypes);
73  }
74 
75  // If the SignatureConversion doesn't apply, bail out.
76  if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
77  &conversion)))
78  return failure();
79 
80  // Update the signature of the function.
81  SmallVector<Type, 2> newResultTypes;
82  if (failed(typeConverter->convertTypes(functionType.getResults(),
83  newResultTypes)))
84  return failure();
85  rewriter.modifyOpInPlace(op, [&] {
86  op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
87  newResultTypes));
88  });
89  return success();
90  }
91 };
92 } // namespace
93 
94 //===----------------------------------------------------------------------===//
95 // DecomposeCallGraphTypesForReturnOp
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Expand return operands according to the provided TypeConverter.
100 struct DecomposeCallGraphTypesForReturnOp
101  : public OpConversionPattern<ReturnOp> {
103 
104  LogicalResult
105  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
106  ConversionPatternRewriter &rewriter) const final {
107  SmallVector<Value, 2> newOperands;
108  for (Value operand : adaptor.getOperands()) {
109  // TODO: We can directly take the values from the adaptor once this is a
110  // 1:N conversion pattern.
111  llvm::append_range(newOperands,
112  decomposeValue(rewriter, operand.getLoc(), operand,
113  getTypeConverter()));
114  }
115  rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
116  return success();
117  }
118 };
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // DecomposeCallGraphTypesForCallOp
123 //===----------------------------------------------------------------------===//
124 
125 namespace {
126 /// Expand call op operands and results according to the provided TypeConverter.
127 struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
129 
130  LogicalResult
131  matchAndRewrite(CallOp op, OpAdaptor adaptor,
132  ConversionPatternRewriter &rewriter) const final {
133 
134  // Create the operands list of the new `CallOp`.
135  SmallVector<Value, 2> newOperands;
136  for (Value operand : adaptor.getOperands()) {
137  // TODO: We can directly take the values from the adaptor once this is a
138  // 1:N conversion pattern.
139  llvm::append_range(newOperands,
140  decomposeValue(rewriter, operand.getLoc(), operand,
141  getTypeConverter()));
142  }
143 
144  // Create the new result types for the new `CallOp` and track the number of
145  // replacement types for each original op result.
146  SmallVector<Type, 2> newResultTypes;
147  SmallVector<unsigned> expandedResultSizes;
148  for (Type resultType : op.getResultTypes()) {
149  unsigned oldSize = newResultTypes.size();
150  if (failed(typeConverter->convertType(resultType, newResultTypes)))
151  return failure();
152  expandedResultSizes.push_back(newResultTypes.size() - oldSize);
153  }
154 
155  CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
156  newResultTypes, newOperands);
157 
158  // Build a replacement value for each result to replace its uses.
159  SmallVector<ValueRange> replacedValues;
160  replacedValues.reserve(op.getNumResults());
161  unsigned startIdx = 0;
162  for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
163  ValueRange repl =
164  newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
165  replacedValues.push_back(repl);
166  startIdx += expandedResultSizes[i];
167  }
168  rewriter.replaceOpWithMultiple(op, replacedValues);
169  return success();
170  }
171 };
172 } // namespace
173 
175  MLIRContext *context, const TypeConverter &typeConverter,
176  RewritePatternSet &patterns) {
177  patterns
178  .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
179  DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
180 }
static SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value value, const TypeConverter *converter)
If the given value can be decomposed with the type converter, decompose it.
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class provides all of the information necessary to convert a type signature.
Type conversion class.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context, const TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the patterns needed to drive the conversion process for decomposing call graph types with the given `TypeConverter`.