12 #include "llvm/ADT/SmallSet.h"
14 #include <unordered_map>
19 std::optional<SmallVector<Value>>
20 OneToNTypeConverter::materializeTargetConversion(
OpBuilder &builder,
25 llvm::reverse(oneToNTargetMaterializations)) {
27 fn(builder, resultTypes, input, loc))
33 TypeRange OneToNTypeMapping::getConvertedTypes(
unsigned originalTypeNo)
const {
34 TypeRange convertedTypes = getConvertedTypes();
35 if (
auto mapping = getInputMapping(originalTypeNo))
36 return convertedTypes.slice(mapping->inputNo, mapping->size);
41 OneToNTypeMapping::getConvertedValues(
ValueRange convertedValues,
42 unsigned originalValueNo)
const {
43 if (
auto mapping = getInputMapping(originalValueNo))
44 return convertedValues.slice(mapping->inputNo, mapping->size);
48 void OneToNTypeMapping::convertLocation(
49 Value originalValue,
unsigned originalValueNo,
51 if (
auto mapping = getInputMapping(originalValueNo))
52 result.append(mapping->size, originalValue.
getLoc());
55 void OneToNTypeMapping::convertLocations(
57 assert(originalValues.size() == getOriginalTypes().size());
59 convertLocation(value, i, result);
63 return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
66 bool OneToNTypeMapping::hasNonIdentityConversion()
const {
74 assert(
TypeRange(originalTypes) != getConvertedTypes());
78 assert(
TypeRange(originalTypes) == getConvertedTypes());
100 static const std::unordered_map<CastKind, StringRef> castKindNames = {
102 {CastKind::Source,
"source"},
103 {CastKind::Target,
"target"}};
104 return castKindNames.at(kind);
110 "__one-to-n-type-conversion_cast-kind__";
118 if (resultTypes.empty())
124 loc = inputs.front().getLoc();
126 builder.
create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
132 return castOp->getResults();
157 convertedValues.push_back(originalValue);
164 convertedValues.append(castResult.begin(), castResult.end());
167 return convertedValues;
190 recastValues.reserve(originalTypes.size());
191 auto convertedValueIt = convertedValues.begin();
194 size_t numConvertedValues = convertedTypes.size();
197 recastValues.push_back(*convertedValueIt);
201 rewriter, originalType,
202 ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
204 assert(recastValue.size() == 1);
205 recastValues.push_back(recastValue.front());
207 convertedValueIt += numConvertedValues;
223 replaceOp(op, castResults);
226 Block *OneToNPatternRewriter::applySignatureConversion(
236 replaceAllUsesWith(block, newBlock);
246 assert(newArgs.size() == 1);
247 castResults.push_back(newArgs.front());
255 assert(castResult.size() == 1);
256 castResults.push_back(castResult.front());
262 mergeBlocks(block, newBlock, castResults);
270 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
275 if (
failed(typeConverter->computeTypeMapping(originalResultTypes,
282 if (
failed(typeConverter->computeTypeMapping(originalOperandTypes,
288 op->
getOperands(), operandMapping, rewriter, CastKind::Target);
302 if (
failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
303 resultMapping, convertedOperands)))
325 SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
326 op->
walk([&](UnrealizedConversionCastOp castOp) {
328 existingCasts.insert(castOp);
340 op->
walk([&](UnrealizedConversionCastOp castOp) {
342 assert(!existingCasts.contains(castOp));
343 worklist.push_back(castOp);
349 for (UnrealizedConversionCastOp castOp : worklist) {
350 TypeRange resultTypes = castOp->getResultTypes();
360 bool areOperandTypesLegal = llvm::all_of(
361 operands.
getTypes(), [&](
Type t) { return typeConverter.isLegal(t); });
362 bool areResultsTypesLegal = llvm::all_of(
363 resultTypes, [&](
Type t) {
return typeConverter.
isLegal(t); });
370 assert(!areOperandTypesLegal && areResultsTypesLegal &&
371 operands.size() == 1 &&
"found unexpected target cast");
372 std::optional<SmallVector<Value>> maybeResults =
374 rewriter, castOp->getLoc(), resultTypes, operands.front());
377 <<
"failed to create target materialization";
380 materializedResults = maybeResults.value();
383 assert(areOperandTypesLegal && !areResultsTypesLegal &&
384 resultTypes.size() == 1 &&
"found unexpected cast");
385 std::optional<Value> maybeResult;
389 rewriter, castOp->getLoc(), resultTypes.front(),
390 castOp.getOperands());
394 "unexpected value of cast kind attribute");
395 assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
397 rewriter, castOp->getLoc(), resultTypes.front(),
398 castOp.getOperands());
400 if (!maybeResult.has_value() || !maybeResult.value()) {
402 <<
"failed to create " << castKind <<
" materialization";
405 materializedResults = {maybeResult.value()};
409 rewriter.
replaceOp(castOp, materializedResults);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Specialization of PatternRewriter that OneToNConversionPatterns use.
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the lociation of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.