MLIR  18.0.0git
OneToNTypeConversion.cpp
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/ErrorHandling.h"

#include <unordered_map>
15 
16 using namespace llvm;
17 using namespace mlir;
18 
19 std::optional<SmallVector<Value>>
20 OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
21  Location loc,
22  TypeRange resultTypes,
23  Value input) const {
24  for (const OneToNMaterializationCallbackFn &fn :
25  llvm::reverse(oneToNTargetMaterializations)) {
26  if (std::optional<SmallVector<Value>> result =
27  fn(builder, resultTypes, input, loc))
28  return *result;
29  }
30  return std::nullopt;
31 }
32 
33 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
34  TypeRange convertedTypes = getConvertedTypes();
35  if (auto mapping = getInputMapping(originalTypeNo))
36  return convertedTypes.slice(mapping->inputNo, mapping->size);
37  return {};
38 }
39 
41 OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
42  unsigned originalValueNo) const {
43  if (auto mapping = getInputMapping(originalValueNo))
44  return convertedValues.slice(mapping->inputNo, mapping->size);
45  return {};
46 }
47 
48 void OneToNTypeMapping::convertLocation(
49  Value originalValue, unsigned originalValueNo,
50  llvm::SmallVectorImpl<Location> &result) const {
51  if (auto mapping = getInputMapping(originalValueNo))
52  result.append(mapping->size, originalValue.getLoc());
53 }
54 
55 void OneToNTypeMapping::convertLocations(
56  ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
57  assert(originalValues.size() == getOriginalTypes().size());
58  for (auto [i, value] : llvm::enumerate(originalValues))
59  convertLocation(value, i, result);
60 }
61 
62 static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
63  return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
64 }
65 
66 bool OneToNTypeMapping::hasNonIdentityConversion() const {
67  // XXX: I think that the original types and the converted types are the same
68  // iff there was no non-identity type conversion. If that is true, the
69  // patterns could actually test whether there is anything useful to do
70  // without having access to the signature conversion.
71  for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
72  TypeRange types = getConvertedTypes(i);
73  if (!isIdentityConversion(originalType, types)) {
74  assert(TypeRange(originalTypes) != getConvertedTypes());
75  return true;
76  }
77  }
78  assert(TypeRange(originalTypes) == getConvertedTypes());
79  return false;
80 }
81 
namespace {
/// The kind of an unrealized cast inserted by the 1:N conversion driver. If the
/// cast does not fold away, its kind selects which user materialization
/// replaces it.
enum class CastKind {
  /// Casts block arguments in the target type back to the source type. (If
  /// necessary, this cast becomes an argument materialization.)
  Argument,
  /// Casts other values in the target type back to the source type. (If
  /// necessary, this cast becomes a source materialization.)
  Source,
  /// Casts values in the source type to the target type. (If necessary, this
  /// cast becomes a target materialization.)
  Target
};
} // namespace
97 
98 /// Mapping of enum values to string values.
99 StringRef getCastKindName(CastKind kind) {
100  static const std::unordered_map<CastKind, StringRef> castKindNames = {
101  {CastKind::Argument, "argument"},
102  {CastKind::Source, "source"},
103  {CastKind::Target, "target"}};
104  return castKindNames.at(kind);
105 }
106 
/// Name of the attribute that tags every unrealized cast inserted by this
/// driver with its kind ("source", "argument", or "target"). Casts carrying it
/// are the ones replaced with user materializations after pattern application.
static constexpr const char *castKindAttrName =
    "__one-to-n-type-conversion_cast-kind__";
111 
112 /// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
113 /// result types. Returns the result values of the cast.
114 static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
115  ValueRange inputs, CastKind kind) {
116  // Create cast.
117  Location loc = builder.getUnknownLoc();
118  if (!inputs.empty())
119  loc = inputs.front().getLoc();
120  auto castOp =
121  builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
122 
123  // Store cast kind as attribute.
124  auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
125  castOp->setAttr(castKindAttrName, kindAttr);
126 
127  return castOp->getResults();
128 }
129 
130 /// Builds one `UnrealizedConversionCastOp` for each of the given original
131 /// values using the respective target types given in the provided conversion
132 /// mapping and returns the results of these casts. If the conversion mapping of
133 /// a value maps a type to itself (i.e., is an identity conversion), then no
134 /// cast is inserted and the original value is returned instead.
135 /// Note that these unrealized casts are different from target materializations
136 /// in that they are *always* inserted, even if they immediately fold away, such
137 /// that patterns always see valid intermediate IR, whereas materializations are
138 /// only used in the places where the unrealized casts *don't* fold away.
139 static SmallVector<Value>
141  OneToNTypeMapping &conversion,
142  RewriterBase &rewriter, CastKind kind) {
143 
144  // Convert each operand one by one.
145  SmallVector<Value> convertedValues;
146  convertedValues.reserve(conversion.getConvertedTypes().size());
147  for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
148  TypeRange convertedTypes = conversion.getConvertedTypes(idx);
149 
150  // Identity conversion: keep operand as is.
151  if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
152  convertedValues.push_back(originalValue);
153  continue;
154  }
155 
156  // Non-identity conversion: materialize target types.
157  ValueRange castResult =
158  buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
159  convertedValues.append(castResult.begin(), castResult.end());
160  }
161 
162  return convertedValues;
163 }
164 
165 /// Builds one `UnrealizedConversionCastOp` for each sequence of the given
166 /// original values to one value of the type they originated from, i.e., a
167 /// "reverse" conversion from N converted values back to one value of the
168 /// original type, using the given (forward) type conversion. If a given value
169 /// was mapped to a value of the same type (i.e., the conversion in the mapping
170 /// is an identity conversion), then the "converted" value is returned without
171 /// cast.
172 /// Note that these unrealized casts are different from source materializations
173 /// in that they are *always* inserted, even if they immediately fold away, such
174 /// that patterns always see valid intermediate IR, whereas materializations are
175 /// only used in the places where the unrealized casts *don't* fold away.
176 static SmallVector<Value>
178  const OneToNTypeMapping &typeConversion,
179  RewriterBase &rewriter) {
180  assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
181 
182  // Create unrealized cast op for each converted result of the op.
183  SmallVector<Value> recastValues;
184  TypeRange originalTypes = typeConversion.getOriginalTypes();
185  recastValues.reserve(originalTypes.size());
186  auto convertedValueIt = convertedValues.begin();
187  for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
188  TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
189  size_t numConvertedValues = convertedTypes.size();
190  if (isIdentityConversion(originalType, convertedTypes)) {
191  // Identity conversion: take result as is.
192  recastValues.push_back(*convertedValueIt);
193  } else {
194  // Non-identity conversion: cast back to source type.
195  ValueRange recastValue = buildUnrealizedCast(
196  rewriter, originalType,
197  ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
198  CastKind::Source);
199  assert(recastValue.size() == 1);
200  recastValues.push_back(recastValue.front());
201  }
202  convertedValueIt += numConvertedValues;
203  }
204 
205  return recastValues;
206 }
207 
/// Replaces `op` with `newValues`, the values of the converted result types
/// described by `resultMapping`. Each group of new values that belongs to one
/// original result is cast back to that result's original type (identity
/// conversions are passed through without a cast), and the results of `op`
/// are replaced with the resulting values.
208 void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
209  const OneToNTypeMapping &resultMapping) {
210  // Create a cast back to the original types and replace the results of the
211  // original op with those.
212  assert(newValues.size() == resultMapping.getConvertedTypes().size());
213  assert(op->getResultTypes() == resultMapping.getOriginalTypes());
  // NOTE(review): original lines 214-215 are missing from this listing;
  // presumably they set up the insertion point for the casts -- confirm.
216  SmallVector<Value> castResults =
217  buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
  // Delegates to the two-argument `replaceOp` overload (presumably the one
  // inherited from `RewriterBase`).
218  replaceOp(op, castResults);
219 }
220 
/// Converts the signature of `block` as described by `argumentConversion`.
/// It creates a new block whose arguments have the converted types and
/// redirects all uses of `block` to the new block. Each group of converted
/// arguments is then cast back to the original argument type (unless the
/// conversion is an identity), and `block` is merged into the new block.
/// Returns the new block.
221 Block *OneToNPatternRewriter::applySignatureConversion(
222  Block *block, OneToNTypeMapping &argumentConversion) {
  // NOTE(review): original lines 223, 227, 246 and 249 are missing from this
  // listing; presumably they include the declaration of `locs` and the last
  // argument of the `buildUnrealizedCast` call below -- confirm.
224 
225  // Create a new block with the converted argument types (and one location
226  // per converted type) to use for the updated signature.
228  argumentConversion.convertLocations(block->getArguments(), locs);
229  Block *newBlock =
230  createBlock(block, argumentConversion.getConvertedTypes(), locs);
231  replaceAllUsesWith(block, newBlock);
232 
233  // Create necessary casts in new block.
  // `castResults` receives exactly one value per argument of `block`; these
  // values stand in for the old arguments when the blocks are merged below.
234  SmallVector<Value> castResults;
235  for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
236  TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
237  ValueRange newArgs =
238  argumentConversion.getConvertedValues(newBlock->getArguments(), i);
239  if (isIdentityConversion(arg.getType(), convertedTypes)) {
240  // Identity conversion: take argument as is.
241  assert(newArgs.size() == 1);
242  castResults.push_back(newArgs.front());
243  } else {
244  // Non-identity conversion: cast the converted arguments to the original
245  // type.
  // The cast is inserted at the start of `newBlock`, so it precedes the
  // operations that `mergeBlocks` moves in below (presumably at its end).
247  setInsertionPointToStart(newBlock);
248  ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
250  assert(castResult.size() == 1);
251  castResults.push_back(castResult.front());
252  }
253  }
254 
255  // Merge old block into new block such that we only have the latter with the
256  // new signature.
257  mergeBlocks(block, newBlock, castResults);
258 
259  return newBlock;
260 }
261 
263 OneToNConversionPattern::matchAndRewrite(Operation *op,
264  PatternRewriter &rewriter) const {
  // Generic entry point called by the pattern driver. It computes the 1:N
  // type mappings of the op's results and operands and casts the operands to
  // their target types. It then invokes the 1:N-specific `matchAndRewrite`
  // overload with a `OneToNPatternRewriter`. Fails if either type mapping
  // cannot be computed or if that overload fails.
  // NOTE(review): original line 262 (presumably the return type) and line
  // 282 (presumably the declaration of `convertedOperands`) are missing from
  // this listing -- confirm.
265  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
266 
267  // Construct conversion mapping for results.
268  Operation::result_type_range originalResultTypes = op->getResultTypes();
269  OneToNTypeMapping resultMapping(originalResultTypes);
270  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
271  resultMapping)))
272  return failure();
273 
274  // Construct conversion mapping for operands.
275  Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
276  OneToNTypeMapping operandMapping(originalOperandTypes);
277  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
278  operandMapping)))
279  return failure();
280 
281  // Cast operands to target types.
  // NOTE(review): these casts are created before it is known whether the
  // pattern matches; if the overload below fails, IR has already been
  // modified -- confirm this is acceptable to the pattern driver.
283  op->getOperands(), operandMapping, rewriter, CastKind::Target);
284 
285  // Create a `OneToNPatternRewriter` for the pattern, which provides additional
286  // functionality.
287  // TODO(ingomueller): I guess it would be better to use only one rewriter
288  // throughout the whole pass, but that would require to
289  // drive the pattern application ourselves, which is a lot
290  // of additional boilerplate code. This seems to work fine,
291  // so I leave it like this for the time being.
  // The new rewriter shares the driver's listener and starts at the driver's
  // insertion point, so its IR changes are presumably still reported.
292  OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(),
293  rewriter.getListener());
294  oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
295 
296  // Apply actual pattern.
297  if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
298  resultMapping, convertedOperands)))
299  return failure();
300 
301  return success();
302 }
303 
304 namespace mlir {
305 
306 // This function applies the provided patterns using
307 // `applyPatternsAndFoldGreedily` and then replaces all newly inserted
308 // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
309 // from target to source types inserted by a `OneToNConversionPattern` normally
310 // fold away with the "forward" casts from source to target types inserted by
311 // the next pattern.) To understand which casts are "newly inserted", all casts
312 // inserted by this pass are annotated with a string attribute that also
313 // documents the kind of the cast (source, argument, or target).
// Unrealized casts that existed before the pass carry no such attribute and
// are left untouched. Returns failure if the greedy pattern application fails
// or if a required materialization cannot be created; in the latter case,
// casts processed earlier in the loop have already been replaced.
// NOTE(review): original lines 314-315, which presumably hold the beginning of
// the signature (`LogicalResult applyPartialOneToNConversion(Operation *op,
// OneToNTypeConverter &typeConverter,`), are missing from this listing.
316  const FrozenRewritePatternSet &patterns) {
317 #ifndef NDEBUG
318  // Remember existing unrealized casts. This data structure is only used in
319  // asserts; building it only for that purpose may be overkill.
320  SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
321  op->walk([&](UnrealizedConversionCastOp castOp) {
322  assert(!castOp->hasAttr(castKindAttrName));
323  existingCasts.insert(castOp);
324  });
325 #endif // NDEBUG
326 
327  // Apply provided conversion patterns.
328  if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
329  emitError(op->getLoc()) << "failed to apply conversion patterns";
330  return failure();
331  }
332 
333  // Find all unrealized casts inserted by the pass that haven't folded away.
  // NOTE(review): original line 334, presumably the declaration of `worklist`
  // (a container of `UnrealizedConversionCastOp`), is missing from this
  // listing.
335  op->walk([&](UnrealizedConversionCastOp castOp) {
336  if (castOp->hasAttr(castKindAttrName)) {
337  assert(!existingCasts.contains(castOp));
338  worklist.push_back(castOp);
339  }
340  });
341 
342  // Replace new casts with user materializations.
  // A standalone rewriter (not tied to a pattern driver) performs the
  // replacements; each materialization is inserted right before its cast.
343  IRRewriter rewriter(op->getContext());
344  for (UnrealizedConversionCastOp castOp : worklist) {
345  TypeRange resultTypes = castOp->getResultTypes();
346  ValueRange operands = castOp->getOperands();
347  StringRef castKind =
348  castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
349  rewriter.setInsertionPoint(castOp);
350 
351 #ifndef NDEBUG
352  // Determine whether operands or results are already legal to test some
353  // assumptions for the different kinds of materializations. These properties
354  // are only used in asserts and it may be overkill to compute them.
355  bool areOperandTypesLegal = llvm::all_of(
356  operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
357  bool areResultsTypesLegal = llvm::all_of(
358  resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
359 #endif // NDEBUG
360 
361  // Add materialization and remember materialized results.
362  SmallVector<Value> materializedResults;
363  if (castKind == getCastKindName(CastKind::Target)) {
364  // Target materialization.
  // Converts one illegal (source-typed) value into N legal values.
365  assert(!areOperandTypesLegal && areResultsTypesLegal &&
366  operands.size() == 1 && "found unexpected target cast");
367  std::optional<SmallVector<Value>> maybeResults =
368  typeConverter.materializeTargetConversion(
369  rewriter, castOp->getLoc(), resultTypes, operands.front());
370  if (!maybeResults) {
371  emitError(castOp->getLoc())
372  << "failed to create target materialization";
373  return failure();
374  }
375  materializedResults = maybeResults.value();
376  } else {
377  // Source and argument materializations.
  // Both convert legal values back into one value of an illegal type.
378  assert(areOperandTypesLegal && !areResultsTypesLegal &&
379  resultTypes.size() == 1 && "found unexpected cast");
380  std::optional<Value> maybeResult;
381  if (castKind == getCastKindName(CastKind::Source)) {
382  // Source materialization.
383  maybeResult = typeConverter.materializeSourceConversion(
384  rewriter, castOp->getLoc(), resultTypes.front(),
385  castOp.getOperands());
386  } else {
387  // Argument materialization.
388  assert(castKind == getCastKindName(CastKind::Argument) &&
389  "unexpected value of cast kind attribute");
390  assert(llvm::all_of(operands,
391  [&](Value v) { return isa<BlockArgument>(v); }));
392  maybeResult = typeConverter.materializeArgumentConversion(
393  rewriter, castOp->getLoc(), resultTypes.front(),
394  castOp.getOperands());
395  }
  // A null value is treated as a failed materialization.
396  if (!maybeResult.has_value() || !maybeResult.value()) {
397  emitError(castOp->getLoc())
398  << "failed to create " << castKind << " materialization";
399  return failure();
400  }
401  materializedResults = {maybeResult.value()};
402  }
403 
404  // Replace the cast with the result of the materialization.
405  rewriter.replaceOp(castOp, materializedResults);
406  }
407 
408  return success();
409 }
410 
411 } // namespace mlir
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:201
static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, CastKind kind)
Builds an UnrealizedConversionCastOp from the given inputs to the given result types.
static SmallVector< Value > buildUnrealizedForwardCasts(ValueRange originalValues, OneToNTypeMapping &conversion, RewriterBase &rewriter, CastKind kind)
Builds one UnrealizedConversionCastOp for each of the given original values using the respective targ...
StringRef getCastKindName(CastKind kind)
Mapping of enum values to string values.
static bool isIdentityConversion(Type originalType, TypeRange convertedTypes)
static SmallVector< Value > buildUnrealizedBackwardsCasts(ValueRange convertedValues, const OneToNTypeMapping &typeConversion, RewriterBase &rewriter)
Builds one UnrealizedConversionCastOp for each sequence of the given converted values back to one value of ...
static const char *const castKindAttrName
Attribute name that is used to annotate inserted unrealized casts with their kind (source,...
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Specialization of PatternRewriter that OneToNConversionPatterns use.
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the number of types that value is converted to.
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:370
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:375
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:305
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getTypes() const
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26