MLIR  19.0.0git
OneToNTypeConversion.cpp
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "llvm/ADT/SmallSet.h"
13 
14 #include <unordered_map>
15 
16 using namespace llvm;
17 using namespace mlir;
18 
19 std::optional<SmallVector<Value>>
20 OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
21  Location loc,
22  TypeRange resultTypes,
23  Value input) const {
24  for (const OneToNMaterializationCallbackFn &fn :
25  llvm::reverse(oneToNTargetMaterializations)) {
26  if (std::optional<SmallVector<Value>> result =
27  fn(builder, resultTypes, input, loc))
28  return *result;
29  }
30  return std::nullopt;
31 }
32 
33 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
34  TypeRange convertedTypes = getConvertedTypes();
35  if (auto mapping = getInputMapping(originalTypeNo))
36  return convertedTypes.slice(mapping->inputNo, mapping->size);
37  return {};
38 }
39 
41 OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
42  unsigned originalValueNo) const {
43  if (auto mapping = getInputMapping(originalValueNo))
44  return convertedValues.slice(mapping->inputNo, mapping->size);
45  return {};
46 }
47 
48 void OneToNTypeMapping::convertLocation(
49  Value originalValue, unsigned originalValueNo,
50  llvm::SmallVectorImpl<Location> &result) const {
51  if (auto mapping = getInputMapping(originalValueNo))
52  result.append(mapping->size, originalValue.getLoc());
53 }
54 
55 void OneToNTypeMapping::convertLocations(
56  ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
57  assert(originalValues.size() == getOriginalTypes().size());
58  for (auto [i, value] : llvm::enumerate(originalValues))
59  convertLocation(value, i, result);
60 }
61 
62 static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
63  return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
64 }
65 
66 bool OneToNTypeMapping::hasNonIdentityConversion() const {
67  // XXX: I think that the original types and the converted types are the same
68  // iff there was no non-identity type conversion. If that is true, the
69  // patterns could actually test whether there is anything useful to do
70  // without having access to the signature conversion.
71  for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
72  TypeRange types = getConvertedTypes(i);
73  if (!isIdentityConversion(originalType, types)) {
74  assert(TypeRange(originalTypes) != getConvertedTypes());
75  return true;
76  }
77  }
78  assert(TypeRange(originalTypes) == getConvertedTypes());
79  return false;
80 }
81 
82 namespace {
/// Kind of an unrealized cast inserted by this file. The kind is attached to
/// every inserted cast as a string attribute (see `getCastKindName` and
/// `castKindAttrName`); `applyPartialOneToNConversion` reads it back to decide
/// which kind of user materialization replaces a cast that did not fold away.
83 enum class CastKind {
84  // Casts block arguments in the target type back to the source type. (If
85  // necessary, this cast becomes an argument materialization.)
86  Argument,
87 
88  // Casts other values in the target type back to the source type. (If
89  // necessary, this cast becomes a source materialization.)
90  Source,
91 
92  // Casts values in the source type to the target type. (If necessary, this
93  // cast becomes a target materialization.)
94  Target
95 };
96 } // namespace
97 
98 /// Mapping of enum values to string values.
99 StringRef getCastKindName(CastKind kind) {
100  static const std::unordered_map<CastKind, StringRef> castKindNames = {
101  {CastKind::Argument, "argument"},
102  {CastKind::Source, "source"},
103  {CastKind::Target, "target"}};
104  return castKindNames.at(kind);
105 }
106 
107 /// Attribute name that is used to annotate inserted unrealized casts with their
108 /// kind (source, argument, or target).
/// The attribute value is a `StringAttr` holding the name returned by
/// `getCastKindName`. `applyPartialOneToNConversion` treats every
/// `UnrealizedConversionCastOp` carrying this attribute as a cast inserted by
/// this file and expects that no such cast exists before it runs.
109 static const char *const castKindAttrName =
110  "__one-to-n-type-conversion_cast-kind__";
111 
112 /// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
113 /// result types. Returns the result values of the cast.
114 static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
115  ValueRange inputs, CastKind kind) {
116  // Special case: 1-to-N conversion with N = 0. No need to build an
117  // UnrealizedConversionCastOp because the op will always be dead.
118  if (resultTypes.empty())
119  return ValueRange();
120 
121  // Create cast.
122  Location loc = builder.getUnknownLoc();
123  if (!inputs.empty())
124  loc = inputs.front().getLoc();
125  auto castOp =
126  builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
127 
128  // Store cast kind as attribute.
129  auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
130  castOp->setAttr(castKindAttrName, kindAttr);
131 
132  return castOp->getResults();
133 }
134 
135 /// Builds one `UnrealizedConversionCastOp` for each of the given original
136 /// values using the respective target types given in the provided conversion
137 /// mapping and returns the results of these casts. If the conversion mapping of
138 /// a value maps a type to itself (i.e., is an identity conversion), then no
139 /// cast is inserted and the original value is returned instead.
140 /// Note that these unrealized casts are different from target materializations
141 /// in that they are *always* inserted, even if they immediately fold away, such
142 /// that patterns always see valid intermediate IR, whereas materializations are
143 /// only used in the places where the unrealized casts *don't* fold away.
144 static SmallVector<Value>
146  OneToNTypeMapping &conversion,
147  RewriterBase &rewriter, CastKind kind) {
148 
149  // Convert each operand one by one.
150  SmallVector<Value> convertedValues;
151  convertedValues.reserve(conversion.getConvertedTypes().size());
152  for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
153  TypeRange convertedTypes = conversion.getConvertedTypes(idx);
154 
155  // Identity conversion: keep operand as is.
156  if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
157  convertedValues.push_back(originalValue);
158  continue;
159  }
160 
161  // Non-identity conversion: materialize target types.
162  ValueRange castResult =
163  buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
164  convertedValues.append(castResult.begin(), castResult.end());
165  }
166 
167  return convertedValues;
168 }
169 
170 /// Builds one `UnrealizedConversionCastOp` for each sequence of the given
171 /// original values to one value of the type they originated from, i.e., a
172 /// "reverse" conversion from N converted values back to one value of the
173 /// original type, using the given (forward) type conversion. If a given value
174 /// was mapped to a value of the same type (i.e., the conversion in the mapping
175 /// is an identity conversion), then the "converted" value is returned without
176 /// cast.
177 /// Note that these unrealized casts are different from source materializations
178 /// in that they are *always* inserted, even if they immediately fold away, such
179 /// that patterns always see valid intermediate IR, whereas materializations are
180 /// only used in the places where the unrealized casts *don't* fold away.
181 static SmallVector<Value>
183  const OneToNTypeMapping &typeConversion,
184  RewriterBase &rewriter) {
185  assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
186 
187  // Create unrealized cast op for each converted result of the op.
188  SmallVector<Value> recastValues;
189  TypeRange originalTypes = typeConversion.getOriginalTypes();
190  recastValues.reserve(originalTypes.size());
191  auto convertedValueIt = convertedValues.begin();
192  for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
193  TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
194  size_t numConvertedValues = convertedTypes.size();
195  if (isIdentityConversion(originalType, convertedTypes)) {
196  // Identity conversion: take result as is.
197  recastValues.push_back(*convertedValueIt);
198  } else {
199  // Non-identity conversion: cast back to source type.
200  ValueRange recastValue = buildUnrealizedCast(
201  rewriter, originalType,
202  ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
203  CastKind::Source);
204  assert(recastValue.size() == 1);
205  recastValues.push_back(recastValue.front());
206  }
207  convertedValueIt += numConvertedValues;
208  }
209 
210  return recastValues;
211 }
212 
213 void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
214  const OneToNTypeMapping &resultMapping) {
215  // Create a cast back to the original types and replace the results of the
216  // original op with those.
217  assert(newValues.size() == resultMapping.getConvertedTypes().size());
218  assert(op->getResultTypes() == resultMapping.getOriginalTypes());
221  SmallVector<Value> castResults =
222  buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
223  replaceOp(op, castResults);
224 }
225 
226 Block *OneToNPatternRewriter::applySignatureConversion(
227  Block *block, OneToNTypeMapping &argumentConversion) {
229 
230  // Split the block at the beginning to get a new block to use for the
231  // updated signature.
233  argumentConversion.convertLocations(block->getArguments(), locs);
234  Block *newBlock =
235  createBlock(block, argumentConversion.getConvertedTypes(), locs);
236  replaceAllUsesWith(block, newBlock);
237 
238  // Create necessary casts in new block.
239  SmallVector<Value> castResults;
240  for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
241  TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
242  ValueRange newArgs =
243  argumentConversion.getConvertedValues(newBlock->getArguments(), i);
244  if (isIdentityConversion(arg.getType(), convertedTypes)) {
245  // Identity conversion: take argument as is.
246  assert(newArgs.size() == 1);
247  castResults.push_back(newArgs.front());
248  } else {
249  // Non-identity conversion: cast the converted arguments to the original
250  // type.
252  setInsertionPointToStart(newBlock);
253  ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
255  assert(castResult.size() == 1);
256  castResults.push_back(castResult.front());
257  }
258  }
259 
260  // Merge old block into new block such that we only have the latter with the
261  // new signature.
262  mergeBlocks(block, newBlock, castResults);
263 
264  return newBlock;
265 }
266 
268 OneToNConversionPattern::matchAndRewrite(Operation *op,
269  PatternRewriter &rewriter) const {
270  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
271 
272  // Construct conversion mapping for results.
273  Operation::result_type_range originalResultTypes = op->getResultTypes();
274  OneToNTypeMapping resultMapping(originalResultTypes);
275  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
276  resultMapping)))
277  return failure();
278 
279  // Construct conversion mapping for operands.
280  Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
281  OneToNTypeMapping operandMapping(originalOperandTypes);
282  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
283  operandMapping)))
284  return failure();
285 
286  // Cast operands to target types.
288  op->getOperands(), operandMapping, rewriter, CastKind::Target);
289 
290  // Create a `OneToNPatternRewriter` for the pattern, which provides additional
291  // functionality.
292  // TODO(ingomueller): I guess it would be better to use only one rewriter
293  // throughout the whole pass, but that would require to
294  // drive the pattern application ourselves, which is a lot
295  // of additional boilerplate code. This seems to work fine,
296  // so I leave it like this for the time being.
297  OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(),
298  rewriter.getListener());
299  oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
300 
301  // Apply actual pattern.
302  if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
303  resultMapping, convertedOperands)))
304  return failure();
305 
306  return success();
307 }
308 
309 namespace mlir {
310 
311 // This function applies the provided patterns using
312 // `applyPatternsAndFoldGreedily` and then replaces all newly inserted
313 // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
314 // from target to source types inserted by a `OneToNConversionPattern` normally
315 // fold away with the "forward" casts from source to target types inserted by
316 // the next pattern.) To understand which casts are "newly inserted", all casts
317 // inserted by this pass are annotated with a string attribute that also
318 // documents which kind of the cast (source, argument, or target).
321  const FrozenRewritePatternSet &patterns) {
322 #ifndef NDEBUG
323  // Remember existing unrealized casts. This data structure is only used in
324  // asserts; building it only for that purpose may be an overkill.
325  SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
326  op->walk([&](UnrealizedConversionCastOp castOp) {
327  assert(!castOp->hasAttr(castKindAttrName));
328  existingCasts.insert(castOp);
329  });
330 #endif // NDEBUG
331 
332  // Apply provided conversion patterns.
333  if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
334  emitError(op->getLoc()) << "failed to apply conversion patterns";
335  return failure();
336  }
337 
338  // Find all unrealized casts inserted by the pass that haven't folded away.
340  op->walk([&](UnrealizedConversionCastOp castOp) {
341  if (castOp->hasAttr(castKindAttrName)) {
342  assert(!existingCasts.contains(castOp));
343  worklist.push_back(castOp);
344  }
345  });
346 
347  // Replace new casts with user materializations.
348  IRRewriter rewriter(op->getContext());
349  for (UnrealizedConversionCastOp castOp : worklist) {
350  TypeRange resultTypes = castOp->getResultTypes();
351  ValueRange operands = castOp->getOperands();
352  StringRef castKind =
353  castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
354  rewriter.setInsertionPoint(castOp);
355 
356 #ifndef NDEBUG
357  // Determine whether operands or results are already legal to test some
358  // assumptions for the different kind of materializations. These properties
359  // are only used it asserts and it may be overkill to compute them.
360  bool areOperandTypesLegal = llvm::all_of(
361  operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
362  bool areResultsTypesLegal = llvm::all_of(
363  resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
364 #endif // NDEBUG
365 
366  // Add materialization and remember materialized results.
367  SmallVector<Value> materializedResults;
368  if (castKind == getCastKindName(CastKind::Target)) {
369  // Target materialization.
370  assert(!areOperandTypesLegal && areResultsTypesLegal &&
371  operands.size() == 1 && "found unexpected target cast");
372  std::optional<SmallVector<Value>> maybeResults =
373  typeConverter.materializeTargetConversion(
374  rewriter, castOp->getLoc(), resultTypes, operands.front());
375  if (!maybeResults) {
376  emitError(castOp->getLoc())
377  << "failed to create target materialization";
378  return failure();
379  }
380  materializedResults = maybeResults.value();
381  } else {
382  // Source and argument materializations.
383  assert(areOperandTypesLegal && !areResultsTypesLegal &&
384  resultTypes.size() == 1 && "found unexpected cast");
385  std::optional<Value> maybeResult;
386  if (castKind == getCastKindName(CastKind::Source)) {
387  // Source materialization.
388  maybeResult = typeConverter.materializeSourceConversion(
389  rewriter, castOp->getLoc(), resultTypes.front(),
390  castOp.getOperands());
391  } else {
392  // Argument materialization.
393  assert(castKind == getCastKindName(CastKind::Argument) &&
394  "unexpected value of cast kind attribute");
395  assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
396  maybeResult = typeConverter.materializeArgumentConversion(
397  rewriter, castOp->getLoc(), resultTypes.front(),
398  castOp.getOperands());
399  }
400  if (!maybeResult.has_value() || !maybeResult.value()) {
401  emitError(castOp->getLoc())
402  << "failed to create " << castKind << " materialization";
403  return failure();
404  }
405  materializedResults = {maybeResult.value()};
406  }
407 
408  // Replace the cast with the result of the materialization.
409  rewriter.replaceOp(castOp, materializedResults);
410  }
411 
412  return success();
413 }
414 
415 } // namespace mlir
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Definition: IRCore.cpp:210
static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, CastKind kind)
Builds an UnrealizedConversionCastOp from the given inputs to the given result types.
static SmallVector< Value > buildUnrealizedForwardCasts(ValueRange originalValues, OneToNTypeMapping &conversion, RewriterBase &rewriter, CastKind kind)
Builds one UnrealizedConversionCastOp for each of the given original values using the respective targ...
StringRef getCastKindName(CastKind kind)
Mapping of enum values to string values.
static bool isIdentityConversion(Type originalType, TypeRange convertedTypes)
static SmallVector< Value > buildUnrealizedBackwardsCasts(ValueRange convertedValues, const OneToNTypeMapping &typeConversion, RewriterBase &rewriter)
Builds one UnrealizedConversionCastOp for each sequence of the given original values to one value of ...
static const char *const castKindAttrName
Attribute name that is used to annotate inserted unrealized casts with their kind (source,...
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Specialization of PatternRewriter that OneToNConversionPatterns use.
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:387
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:392
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26