12 #include "llvm/ADT/SmallSet.h"
14 #include <unordered_map>
19 std::optional<SmallVector<Value>>
20 OneToNTypeConverter::materializeTargetConversion(
OpBuilder &builder,
25 llvm::reverse(oneToNTargetMaterializations)) {
27 fn(builder, resultTypes, input, loc))
33 TypeRange OneToNTypeMapping::getConvertedTypes(
unsigned originalTypeNo)
const {
34 TypeRange convertedTypes = getConvertedTypes();
35 if (
auto mapping = getInputMapping(originalTypeNo))
36 return convertedTypes.slice(mapping->inputNo, mapping->size);
41 OneToNTypeMapping::getConvertedValues(
ValueRange convertedValues,
42 unsigned originalValueNo)
const {
43 if (
auto mapping = getInputMapping(originalValueNo))
44 return convertedValues.slice(mapping->inputNo, mapping->size);
48 void OneToNTypeMapping::convertLocation(
49 Value originalValue,
unsigned originalValueNo,
51 if (
auto mapping = getInputMapping(originalValueNo))
52 result.append(mapping->size, originalValue.
getLoc());
55 void OneToNTypeMapping::convertLocations(
57 assert(originalValues.size() == getOriginalTypes().size());
59 convertLocation(value, i, result);
63 return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
66 bool OneToNTypeMapping::hasNonIdentityConversion()
const {
74 assert(
TypeRange(originalTypes) != getConvertedTypes());
78 assert(
TypeRange(originalTypes) == getConvertedTypes());
100 static const std::unordered_map<CastKind, StringRef> castKindNames = {
102 {CastKind::Source,
"source"},
103 {CastKind::Target,
"target"}};
104 return castKindNames.at(kind);
110 "__one-to-n-type-conversion_cast-kind__";
119 loc = inputs.front().getLoc();
121 builder.
create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
127 return castOp->getResults();
152 convertedValues.push_back(originalValue);
159 convertedValues.append(castResult.begin(), castResult.end());
162 return convertedValues;
185 recastValues.reserve(originalTypes.size());
186 auto convertedValueIt = convertedValues.begin();
189 size_t numConvertedValues = convertedTypes.size();
192 recastValues.push_back(*convertedValueIt);
196 rewriter, originalType,
197 ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
199 assert(recastValue.size() == 1);
200 recastValues.push_back(recastValue.front());
202 convertedValueIt += numConvertedValues;
218 replaceOp(op, castResults);
221 Block *OneToNPatternRewriter::applySignatureConversion(
231 replaceAllUsesWith(block, newBlock);
241 assert(newArgs.size() == 1);
242 castResults.push_back(newArgs.front());
250 assert(castResult.size() == 1);
251 castResults.push_back(castResult.front());
257 mergeBlocks(block, newBlock, castResults);
265 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
270 if (
failed(typeConverter->computeTypeMapping(originalResultTypes,
277 if (
failed(typeConverter->computeTypeMapping(originalOperandTypes,
283 op->
getOperands(), operandMapping, rewriter, CastKind::Target);
297 if (
failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
298 resultMapping, convertedOperands)))
320 SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
321 op->
walk([&](UnrealizedConversionCastOp castOp) {
323 existingCasts.insert(castOp);
335 op->
walk([&](UnrealizedConversionCastOp castOp) {
337 assert(!existingCasts.contains(castOp));
338 worklist.push_back(castOp);
344 for (UnrealizedConversionCastOp castOp : worklist) {
345 TypeRange resultTypes = castOp->getResultTypes();
355 bool areOperandTypesLegal = llvm::all_of(
356 operands.
getTypes(), [&](
Type t) { return typeConverter.isLegal(t); });
357 bool areResultsTypesLegal = llvm::all_of(
358 resultTypes, [&](
Type t) {
return typeConverter.
isLegal(t); });
365 assert(!areOperandTypesLegal && areResultsTypesLegal &&
366 operands.size() == 1 &&
"found unexpected target cast");
367 std::optional<SmallVector<Value>> maybeResults =
369 rewriter, castOp->getLoc(), resultTypes, operands.front());
372 <<
"failed to create target materialization";
375 materializedResults = maybeResults.value();
378 assert(areOperandTypesLegal && !areResultsTypesLegal &&
379 resultTypes.size() == 1 &&
"found unexpected cast");
380 std::optional<Value> maybeResult;
384 rewriter, castOp->getLoc(), resultTypes.front(),
385 castOp.getOperands());
389 "unexpected value of cast kind attribute");
390 assert(llvm::all_of(operands,
391 [&](
Value v) {
return isa<BlockArgument>(v); }));
393 rewriter, castOp->getLoc(), resultTypes.front(),
394 castOp.getOperands());
396 if (!maybeResult.has_value() || !maybeResult.value()) {
398 <<
"failed to create " << castKind <<
" materialization";
401 materializedResults = {maybeResult.value()};
405 rewriter.
replaceOp(castOp, materializedResults);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Specialization of PatternRewriter that OneToNConversionPatterns use.
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.