17 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
18 #include "mlir/Dialect/Linalg/Passes.h.inc"
24 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
34 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
38 bool noneInvalid = llvm::none_of(types, [](
Type t) {
42 return anyRankedTensor && noneInvalid;
65 res.reserve(rankedTensorTypes.size());
66 for (
Type t : rankedTensorTypes) {
69 for (
Value v : operands) {
70 if (v.getType() == t) {
80 res.push_back(tensor::EmptyOp::create(
82 cast<RankedTensorType>(t).getElementType()));
88 struct ConvertAnyElementwiseMappableOpOnRankedTensors :
public RewritePattern {
89 ConvertAnyElementwiseMappableOpOnRankedTensors(
MLIRContext *context)
91 LogicalResult matchAndRewrite(
Operation *op,
94 return rewriter.notifyMatchFailure(
95 op,
"requires elementwise op on ranked tensors");
97 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
98 auto rank = resTy.getRank();
102 {}, rewriter.getContext());
107 isScalarOperand.reserve(op->getNumOperands());
108 for (
Type ty : op->getOperandTypes()) {
110 isScalarOperand.push_back(
true);
111 else if (
auto rt = dyn_cast<RankedTensorType>(ty))
112 isScalarOperand.push_back(
false);
114 return rewriter.notifyMatchFailure(
116 "unsupported operand type (expected scalar-like or ranked tensor)");
121 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
123 for (
bool isScalar : isScalarOperand)
124 indexingMaps.push_back(isScalar ? scalarMap : idMap);
126 indexingMaps.append(op->getNumResults(), idMap);
129 rank, utils::IteratorType::parallel);
132 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
133 op, op->getResultTypes(),
141 llvm::map_range(op->getResultTypes(), [](
Type type) {
142 return cast<TensorType>(type).getElementType();
145 builder.
create(loc, op->getName().getIdentifier(),
146 regionArgs.take_front(op->getNumOperands()),
147 resultEltTys, op->getAttrs());
148 linalg::YieldOp::create(builder, loc, scalarOp->
getResults());
157 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
162 class ConvertElementwiseToLinalgPass
163 :
public impl::ConvertElementwiseToLinalgPassBase<
164 ConvertElementwiseToLinalgPass> {
165 using impl::ConvertElementwiseToLinalgPassBase<
166 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
168 void runOnOperation() final {
169 auto *func = getOperation();
175 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.