MLIR 23.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
19#include "mlir/Dialect/Linalg/Passes.h.inc"
20} // namespace mlir
21
22using namespace mlir;
23
24static inline bool isScalarLike(Type t) {
25 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
26}
27
30 return false;
31
32 auto types = op->getOperandTypes();
33
34 // We want at least one ranked tensor.
35 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
36
37 // No invalid operands (i.e., every operand is a ranked tensor or
38 // scalar-like).
39 bool noneInvalid = llvm::none_of(types, [](Type t) {
40 return !(isa<RankedTensorType>(t) || isScalarLike(t));
41 });
42
43 return anyRankedTensor && noneInvalid;
44}
45
46/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
47/// the result types and return a list of values such that, for each result type
48/// `t` and value `v` at the same index `idx`:
49/// 1. `v.getType() == t`
50/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
51/// such operand. Then`v == operand_first`.
52/// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
53/// a. Static and dynamic dims extracted from the first operand of `op`.
54/// b. Elemental type equal to the elemental type of `t`.
55///
56/// This is sufficient because ElementwiseMappable guarantees that "The static
57/// types of all vector (resp. tensor) operands and results must have the same
58/// shape".
62 Location loc = op->getLoc();
63 ValueRange operands = op->getOperands();
64 TypeRange rankedTensorTypes = op->getResultTypes();
66 res.reserve(rankedTensorTypes.size());
67 for (Type t : rankedTensorTypes) {
68 // Try to find an operand with type matching the result tensor.
69 bool found = false;
70 for (Value v : operands) {
71 if (v.getType() == t) {
72 found = true;
73 res.push_back(v);
74 break;
75 }
76 }
77 if (found)
78 continue;
79
80 // Extract static / dynamic shape mix from the first operand.
81 res.push_back(tensor::EmptyOp::create(
82 b, loc, tensor::getMixedSizes(b, loc, operands.front()),
83 cast<RankedTensorType>(t).getElementType()));
84 }
85 return res;
86}
87
88namespace {
89struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
90 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
91 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
92 LogicalResult matchAndRewrite(Operation *op,
93 PatternRewriter &rewriter) const final {
95 return rewriter.notifyMatchFailure(
96 op, "requires elementwise op on ranked tensors");
97
98 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
99 auto rank = resTy.getRank();
100
101 // Maps: identity for tensors (rank > 0), scalar map for scalars.
102 AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
103 /*results=*/{}, rewriter.getContext());
104 AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
105
106 // Match phase.
107 SmallVector<bool> isScalarOperand;
108 isScalarOperand.reserve(op->getNumOperands());
109 for (Type ty : op->getOperandTypes()) {
110 if (isScalarLike(ty))
111 isScalarOperand.push_back(true);
112 else if (auto rt = dyn_cast<RankedTensorType>(ty))
113 isScalarOperand.push_back(false);
114 else
115 return rewriter.notifyMatchFailure(
116 op,
117 "unsupported operand type (expected scalar-like or ranked tensor)");
118 }
119
120 // Create indexing maps.
121 SmallVector<AffineMap> indexingMaps;
122 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
123
124 for (bool isScalar : isScalarOperand)
125 indexingMaps.push_back(isScalar ? scalarMap : idMap);
126
127 indexingMaps.append(op->getNumResults(), idMap);
128
129 SmallVector<utils::IteratorType> iteratorTypes(
130 rank, utils::IteratorType::parallel);
131 SmallVector<Value> outputs =
133 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
134 op, /*resultTensorTypes=*/op->getResultTypes(),
135 /*inputs=*/op->getOperands(),
136 /*outputs=*/outputs,
137 /*indexingMaps=*/indexingMaps,
138 /*iteratorTypes=*/iteratorTypes,
139 /*bodyBuilder=*/
140 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
141 SmallVector<Type> resultEltTys =
142 llvm::map_to_vector<6>(op->getResultTypes(), [](Type type) {
143 return cast<TensorType>(type).getElementType();
144 });
145 Operation *scalarOp =
146 builder.create(loc, op->getName().getIdentifier(),
147 regionArgs.take_front(op->getNumOperands()),
148 resultEltTys, op->getAttrs());
149 linalg::YieldOp::create(builder, loc, scalarOp->getResults());
150 });
151 return success();
152 }
153};
154} // namespace
155
157 RewritePatternSet &patterns) {
158 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
159 patterns.getContext());
160}
161
162namespace {
163class ConvertElementwiseToLinalgPass
164 : public impl::ConvertElementwiseToLinalgPassBase<
165 ConvertElementwiseToLinalgPass> {
166 using impl::ConvertElementwiseToLinalgPassBase<
167 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
168
169 void runOnOperation() final {
170 auto *func = getOperation();
171 auto *context = &getContext();
172 ConversionTarget target(*context);
173 RewritePatternSet patterns(context);
174
176 target.markUnknownOpDynamicallyLegal([](Operation *op) {
178 });
179
180 if (failed(applyPartialConversion(func, target, std::move(patterns))))
181 signalPassFailure();
182 }
183};
184} // namespace
return success()
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
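Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over the result types and return a list of values, one per result type `t`. Each value is the first operand of type `t` if one exists, otherwise a new `tensor::EmptyOp` of type `t`.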
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
operand_type_range getOperandTypes()
Definition Operation.h:426
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.