MLIR 22.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15
16namespace mlir {
17#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
18#include "mlir/Dialect/Linalg/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23static inline bool isScalarLike(Type t) {
24 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
25}
26
29 return false;
30
31 auto types = op->getOperandTypes();
32
33 // We want at least one ranked tensor.
34 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
35
36 // No invalid operands (i.e., every operand is a ranked tensor or
37 // scalar-like).
38 bool noneInvalid = llvm::none_of(types, [](Type t) {
39 return !(isa<RankedTensorType>(t) || isScalarLike(t));
40 });
41
42 return anyRankedTensor && noneInvalid;
45/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
46/// the result types and return a list of values such that, for each result type
47/// `t` and value `v` at the same index `idx`:
48/// 1. `v.getType() == t`
49/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
50/// such operand. Then`v == operand_first`.
51/// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
52/// a. Static and dynamic dims extracted from the first operand of `op`.
53/// b. Elemental type equal to the elemental type of `t`.
54///
55/// This is sufficient because ElementwiseMappable guarantees that "The static
56/// types of all vector (resp. tensor) operands and results must have the same
57/// shape".
61 Location loc = op->getLoc();
62 ValueRange operands = op->getOperands();
63 TypeRange rankedTensorTypes = op->getResultTypes();
65 res.reserve(rankedTensorTypes.size());
66 for (Type t : rankedTensorTypes) {
67 // Try to find an operand with type matching the result tensor.
68 bool found = false;
69 for (Value v : operands) {
70 if (v.getType() == t) {
71 found = true;
72 res.push_back(v);
73 break;
74 }
75 }
76 if (found)
77 continue;
78
79 // Extract static / dynamic shape mix from the first operand.
80 res.push_back(tensor::EmptyOp::create(
81 b, loc, tensor::getMixedSizes(b, loc, operands.front()),
82 cast<RankedTensorType>(t).getElementType()));
83 }
84 return res;
85}
86
87namespace {
88struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
89 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
90 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
91 LogicalResult matchAndRewrite(Operation *op,
92 PatternRewriter &rewriter) const final {
94 return rewriter.notifyMatchFailure(
95 op, "requires elementwise op on ranked tensors");
96
97 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
98 auto rank = resTy.getRank();
99
100 // Maps: identity for tensors (rank > 0), scalar map for scalars.
101 AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
102 /*results=*/{}, rewriter.getContext());
103 AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
104
105 // Match phase.
106 SmallVector<bool> isScalarOperand;
107 isScalarOperand.reserve(op->getNumOperands());
108 for (Type ty : op->getOperandTypes()) {
109 if (isScalarLike(ty))
110 isScalarOperand.push_back(true);
111 else if (auto rt = dyn_cast<RankedTensorType>(ty))
112 isScalarOperand.push_back(false);
113 else
114 return rewriter.notifyMatchFailure(
115 op,
116 "unsupported operand type (expected scalar-like or ranked tensor)");
117 }
118
119 // Create indexing maps.
120 SmallVector<AffineMap> indexingMaps;
121 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
122
123 for (bool isScalar : isScalarOperand)
124 indexingMaps.push_back(isScalar ? scalarMap : idMap);
125
126 indexingMaps.append(op->getNumResults(), idMap);
127
128 SmallVector<utils::IteratorType> iteratorTypes(
129 rank, utils::IteratorType::parallel);
130 SmallVector<Value> outputs =
132 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
133 op, /*resultTensorTypes=*/op->getResultTypes(),
134 /*inputs=*/op->getOperands(),
135 /*outputs=*/outputs,
136 /*indexingMaps=*/indexingMaps,
137 /*iteratorTypes=*/iteratorTypes,
138 /*bodyBuilder=*/
139 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
140 SmallVector<Type> resultEltTys = llvm::to_vector<6>(
141 llvm::map_range(op->getResultTypes(), [](Type type) {
142 return cast<TensorType>(type).getElementType();
143 }));
144 Operation *scalarOp =
145 builder.create(loc, op->getName().getIdentifier(),
146 regionArgs.take_front(op->getNumOperands()),
147 resultEltTys, op->getAttrs());
148 linalg::YieldOp::create(builder, loc, scalarOp->getResults());
149 });
150 return success();
151 }
152};
153} // namespace
154
157 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
158 patterns.getContext());
159}
160
161namespace {
162class ConvertElementwiseToLinalgPass
164 ConvertElementwiseToLinalgPass> {
166 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
167
168 void runOnOperation() final {
169 auto *func = getOperation();
170 auto *context = &getContext();
171 ConversionTarget target(*context);
173
175 target.markUnknownOpDynamicallyLegal([](Operation *op) {
177 });
178
179 if (failed(applyPartialConversion(func, target, std::move(patterns))))
180 signalPassFailure();
181 }
182};
183} // namespace
return success()
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns