15#include "llvm/ADT/SmallVectorExtras.h"
18#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
19#include "mlir/Dialect/Linalg/Passes.h.inc"
25 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
35 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
39 bool noneInvalid = llvm::none_of(types, [](
Type t) {
43 return anyRankedTensor && noneInvalid;
66 res.reserve(rankedTensorTypes.size());
67 for (
Type t : rankedTensorTypes) {
70 for (
Value v : operands) {
71 if (v.getType() == t) {
81 res.push_back(tensor::EmptyOp::create(
83 cast<RankedTensorType>(t).getElementType()));
89struct ConvertAnyElementwiseMappableOpOnRankedTensors :
public RewritePattern {
90 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
91 : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
92 LogicalResult matchAndRewrite(Operation *op,
93 PatternRewriter &rewriter)
const final {
95 return rewriter.notifyMatchFailure(
96 op,
"requires elementwise op on ranked tensors");
98 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
99 auto rank = resTy.getRank();
103 {}, rewriter.getContext());
104 AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
107 SmallVector<bool> isScalarOperand;
108 isScalarOperand.reserve(op->getNumOperands());
109 for (Type ty : op->getOperandTypes()) {
111 isScalarOperand.push_back(
true);
112 else if (
auto rt = dyn_cast<RankedTensorType>(ty))
113 isScalarOperand.push_back(
false);
115 return rewriter.notifyMatchFailure(
117 "unsupported operand type (expected scalar-like or ranked tensor)");
121 SmallVector<AffineMap> indexingMaps;
122 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
124 for (
bool isScalar : isScalarOperand)
125 indexingMaps.push_back(isScalar ? scalarMap : idMap);
127 indexingMaps.append(op->getNumResults(), idMap);
129 SmallVector<utils::IteratorType> iteratorTypes(
130 rank, utils::IteratorType::parallel);
131 SmallVector<Value> outputs =
133 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
134 op, op->getResultTypes(),
140 [&](OpBuilder &builder, Location loc,
ValueRange regionArgs) {
141 SmallVector<Type> resultEltTys =
142 llvm::map_to_vector<6>(op->getResultTypes(), [](Type type) {
143 return cast<TensorType>(type).getElementType();
145 Operation *scalarOp =
146 builder.create(loc, op->getName().getIdentifier(),
147 regionArgs.take_front(op->getNumOperands()),
148 resultEltTys, op->getAttrs());
149 linalg::YieldOp::create(builder, loc, scalarOp->
getResults());
158 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
163class ConvertElementwiseToLinalgPass
164 :
public impl::ConvertElementwiseToLinalgPassBase<
165 ConvertElementwiseToLinalgPass> {
166 using impl::ConvertElementwiseToLinalgPassBase<
167 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
169 void runOnOperation() final {
170 auto *
func = getOperation();
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
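Given an op assumed to satisfy isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values, one per result type: an operand of the op whose type matches that result type when there is one, otherwise a newly created tensor::EmptyOp.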
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool hasElementwiseMappableTraits(Operation *op)
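Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.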
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns