18 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG
19 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 [](
Type type) { return isa<RankedTensorType>(type); });
54 res.reserve(rankedTensorTypes.size());
55 for (
Type t : rankedTensorTypes) {
58 for (
Value v : operands) {
59 if (v.getType() == t) {
69 res.push_back(b.
create<tensor::EmptyOp>(
71 cast<RankedTensorType>(t).getElementType()));
77 struct ConvertAnyElementwiseMappableOpOnRankedTensors :
public RewritePattern {
78 ConvertAnyElementwiseMappableOpOnRankedTensors(
MLIRContext *context)
83 return rewriter.notifyMatchFailure(
84 op,
"requires elementwise op on ranked tensors");
89 rewriter.getMultiDimIdentityMap(rank));
91 rank, utils::IteratorType::parallel);
93 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
101 auto resultTypes = llvm::to_vector<6>(
103 return cast<TensorType>(type).getElementType();
109 builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
118 patterns.
add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
123 class ConvertElementwiseToLinalgPass
124 :
public impl::ConvertElementwiseToLinalgBase<
125 ConvertElementwiseToLinalgPass> {
127 void runOnOperation() final {
128 auto *func = getOperation();
134 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
145 return std::make_unique<ConvertElementwiseToLinalgPass>();
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
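Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values, one per result type t: an operand of op whose type equals t when one exists, otherwise a newly created tensor::EmptyOp whose element type is that of t.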
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
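Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.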
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createConvertElementwiseToLinalgPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.