13 #include "llvm/ADT/SmallSet.h"
15 #include <unordered_map>
20 std::optional<SmallVector<Value>>
21 OneToNTypeConverter::materializeTargetConversion(
OpBuilder &builder,
26 llvm::reverse(oneToNTargetMaterializations)) {
28 fn(builder, resultTypes, input, loc))
34 TypeRange OneToNTypeMapping::getConvertedTypes(
unsigned originalTypeNo)
const {
35 TypeRange convertedTypes = getConvertedTypes();
36 if (
auto mapping = getInputMapping(originalTypeNo))
37 return convertedTypes.slice(mapping->inputNo, mapping->size);
42 OneToNTypeMapping::getConvertedValues(
ValueRange convertedValues,
43 unsigned originalValueNo)
const {
44 if (
auto mapping = getInputMapping(originalValueNo))
45 return convertedValues.slice(mapping->inputNo, mapping->size);
49 void OneToNTypeMapping::convertLocation(
50 Value originalValue,
unsigned originalValueNo,
52 if (
auto mapping = getInputMapping(originalValueNo))
53 result.append(mapping->size, originalValue.
getLoc());
56 void OneToNTypeMapping::convertLocations(
58 assert(originalValues.size() == getOriginalTypes().size());
60 convertLocation(value, i, result);
64 return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
67 bool OneToNTypeMapping::hasNonIdentityConversion()
const {
75 assert(
TypeRange(originalTypes) != getConvertedTypes());
79 assert(
TypeRange(originalTypes) == getConvertedTypes());
101 static const std::unordered_map<CastKind, StringRef> castKindNames = {
103 {CastKind::Source,
"source"},
104 {CastKind::Target,
"target"}};
105 return castKindNames.at(kind);
111 "__one-to-n-type-conversion_cast-kind__";
119 if (resultTypes.empty())
125 loc = inputs.front().getLoc();
127 builder.
create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
133 return castOp->getResults();
158 convertedValues.push_back(originalValue);
165 convertedValues.append(castResult.begin(), castResult.end());
168 return convertedValues;
191 recastValues.reserve(originalTypes.size());
192 auto convertedValueIt = convertedValues.begin();
195 size_t numConvertedValues = convertedTypes.size();
198 recastValues.push_back(*convertedValueIt);
202 rewriter, originalType,
203 ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
205 assert(recastValue.size() == 1);
206 recastValues.push_back(recastValue.front());
208 convertedValueIt += numConvertedValues;
224 replaceOp(op, castResults);
227 Block *OneToNPatternRewriter::applySignatureConversion(
237 replaceAllUsesWith(block, newBlock);
247 assert(newArgs.size() == 1);
248 castResults.push_back(newArgs.front());
256 assert(castResult.size() == 1);
257 castResults.push_back(castResult.front());
263 mergeBlocks(block, newBlock, castResults);
271 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
276 if (failed(typeConverter->computeTypeMapping(originalResultTypes,
283 if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
289 op->
getOperands(), operandMapping, rewriter, CastKind::Target);
303 if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
304 resultMapping, convertedOperands)))
326 SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
327 op->
walk([&](UnrealizedConversionCastOp castOp) {
329 existingCasts.insert(castOp);
341 op->
walk([&](UnrealizedConversionCastOp castOp) {
343 assert(!existingCasts.contains(castOp));
344 worklist.push_back(castOp);
350 for (UnrealizedConversionCastOp castOp : worklist) {
351 TypeRange resultTypes = castOp->getResultTypes();
361 bool areOperandTypesLegal = llvm::all_of(
362 operands.
getTypes(), [&](
Type t) { return typeConverter.isLegal(t); });
363 bool areResultsTypesLegal = llvm::all_of(
364 resultTypes, [&](
Type t) {
return typeConverter.
isLegal(t); });
371 assert(!areOperandTypesLegal && areResultsTypesLegal &&
372 operands.size() == 1 &&
"found unexpected target cast");
373 std::optional<SmallVector<Value>> maybeResults =
375 rewriter, castOp->getLoc(), resultTypes, operands.front());
378 <<
"failed to create target materialization";
381 materializedResults = maybeResults.value();
384 assert(areOperandTypesLegal && !areResultsTypesLegal &&
385 resultTypes.size() == 1 &&
"found unexpected cast");
386 std::optional<Value> maybeResult;
390 rewriter, castOp->getLoc(), resultTypes.front(),
391 castOp.getOperands());
395 "unexpected value of cast kind attribute");
396 assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
398 rewriter, castOp->getLoc(), resultTypes.front(),
399 castOp.getOperands());
401 if (!maybeResult.has_value() || !maybeResult.value()) {
403 <<
"failed to create " << castKind <<
" materialization";
406 materializedResults = {maybeResult.value()};
410 rewriter.
replaceOp(castOp, materializedResults);
419 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
428 ValueRange convertedOperands)
const override {
429 auto funcOp = cast<FunctionOpInterface>(op);
430 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
434 if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
440 if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
446 if (!argumentMapping.hasNonIdentityConversion() &&
447 !funcResultMapping.hasNonIdentityConversion())
452 argumentMapping.getConvertedTypes(),
453 funcResultMapping.getConvertedTypes());
457 if (!funcOp.isExternal()) {
458 Region *region = &funcOp.getFunctionBody();
471 patterns.
add<FunctionOpInterfaceSignatureConversion>(
472 functionLikeOpName, patterns.
getContext(), converter);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const py::sequence &pyArgTypes, const std::optional< py::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Base class for patterns with 1:N type conversions.
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the given types.
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the number of values it is converted to.
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateOneToNFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, TypeConverter &converter, RewritePatternSet &patterns)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...