13 #include "llvm/ADT/SmallSet.h"
15 #include <unordered_map>
20 TypeRange OneToNTypeMapping::getConvertedTypes(
unsigned originalTypeNo)
const {
21 TypeRange convertedTypes = getConvertedTypes();
22 if (
auto mapping = getInputMapping(originalTypeNo))
23 return convertedTypes.slice(mapping->inputNo, mapping->size);
28 OneToNTypeMapping::getConvertedValues(
ValueRange convertedValues,
29 unsigned originalValueNo)
const {
30 if (
auto mapping = getInputMapping(originalValueNo))
31 return convertedValues.slice(mapping->inputNo, mapping->size);
35 void OneToNTypeMapping::convertLocation(
36 Value originalValue,
unsigned originalValueNo,
38 if (
auto mapping = getInputMapping(originalValueNo))
39 result.append(mapping->size, originalValue.
getLoc());
42 void OneToNTypeMapping::convertLocations(
44 assert(originalValues.size() == getOriginalTypes().size());
46 convertLocation(value, i, result);
50 return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
53 bool OneToNTypeMapping::hasNonIdentityConversion()
const {
61 assert(
TypeRange(originalTypes) != getConvertedTypes());
65 assert(
TypeRange(originalTypes) == getConvertedTypes());
87 static const std::unordered_map<CastKind, StringRef> castKindNames = {
89 {CastKind::Source,
"source"},
90 {CastKind::Target,
"target"}};
91 return castKindNames.at(kind);
97 "__one-to-n-type-conversion_cast-kind__";
105 if (resultTypes.empty())
111 loc = inputs.front().getLoc();
113 builder.
create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
119 return castOp->getResults();
144 convertedValues.push_back(originalValue);
151 convertedValues.append(castResult.begin(), castResult.end());
154 return convertedValues;
177 recastValues.reserve(originalTypes.size());
178 auto convertedValueIt = convertedValues.begin();
181 size_t numConvertedValues = convertedTypes.size();
184 recastValues.push_back(*convertedValueIt);
188 rewriter, originalType,
189 ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
191 assert(recastValue.size() == 1);
192 recastValues.push_back(recastValue.front());
194 convertedValueIt += numConvertedValues;
210 replaceOp(op, castResults);
213 Block *OneToNPatternRewriter::applySignatureConversion(
223 replaceAllUsesWith(block, newBlock);
233 assert(newArgs.size() == 1);
234 castResults.push_back(newArgs.front());
242 assert(castResult.size() == 1);
243 castResults.push_back(castResult.front());
249 mergeBlocks(block, newBlock, castResults);
275 op->
getOperands(), operandMapping, rewriter, CastKind::Target);
289 if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
290 resultMapping, convertedOperands)))
312 SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
313 op->
walk([&](UnrealizedConversionCastOp castOp) {
315 existingCasts.insert(castOp);
327 op->
walk([&](UnrealizedConversionCastOp castOp) {
329 assert(!existingCasts.contains(castOp));
330 worklist.push_back(castOp);
336 for (UnrealizedConversionCastOp castOp : worklist) {
337 TypeRange resultTypes = castOp->getResultTypes();
347 bool areOperandTypesLegal = llvm::all_of(
348 operands.
getTypes(), [&](
Type t) { return typeConverter.isLegal(t); });
349 bool areResultsTypesLegal = llvm::all_of(
357 assert(!areOperandTypesLegal && areResultsTypesLegal &&
358 operands.size() == 1 &&
"found unexpected target cast");
360 rewriter, castOp->getLoc(), resultTypes, operands.front());
361 if (materializedResults.empty()) {
363 <<
"failed to create target materialization";
368 assert(areOperandTypesLegal && !areResultsTypesLegal &&
369 resultTypes.size() == 1 &&
"found unexpected cast");
370 std::optional<Value> maybeResult;
374 rewriter, castOp->getLoc(), resultTypes.front(),
375 castOp.getOperands());
379 "unexpected value of cast kind attribute");
380 assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
382 rewriter, castOp->getLoc(), resultTypes.front(),
383 castOp.getOperands());
385 if (!maybeResult.has_value() || !maybeResult.value()) {
387 <<
"failed to create " << castKind <<
" materialization";
390 materializedResults = {maybeResult.value()};
394 rewriter.
replaceOp(castOp, materializedResults);
403 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
412 ValueRange convertedOperands)
const override {
413 auto funcOp = cast<FunctionOpInterface>(op);
430 if (!argumentMapping.hasNonIdentityConversion() &&
431 !funcResultMapping.hasNonIdentityConversion())
436 argumentMapping.getConvertedTypes(),
437 funcResultMapping.getConvertedTypes());
441 if (!funcOp.isExternal()) {
442 Region *region = &funcOp.getFunctionBody();
455 patterns.
add<FunctionOpInterfaceSignatureConversion>(
static void setInsertionPointAfter(OpBuilder &b, Value value)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static MlirBlock createBlock(const nb::sequence &pyArgTypes, const std::optional< nb::sequence > &pyArgLocs)
Create a block, using the current location context if no locations are specified.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Base class for patterns with 1:N type conversions.
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the number of values it is converted to.
TypeRange getOriginalTypes() const
Returns the list of original types.
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
RewritePatternSet & patterns
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
TypeConverter & typeConverter
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
const TypeConverter & converter
LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
void populateOneToNFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, const TypeConverter &converter, RewritePatternSet &patterns)