MLIR 23.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "llvm/ADT/SmallVectorExtras.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
19#include "mlir/Dialect/Linalg/Passes.h.inc"
20} // namespace mlir
21
22using namespace mlir;
23
24static inline bool isScalarLike(Type t) {
25 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
26}
27
30 return false;
31
32 auto types = op->getOperandTypes();
33
34 // We want at least one ranked tensor.
35 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
36
37 // No invalid operands (i.e., every operand is a ranked tensor or
38 // scalar-like).
39 bool noneInvalid = llvm::none_of(types, [](Type t) {
40 return !(isa<RankedTensorType>(t) || isScalarLike(t));
41 });
42
43 return anyRankedTensor && noneInvalid;
44}
45
46/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
47/// the result types and return a list of values such that, for each result type
48/// `t` and value `v` at the same index `idx`:
49/// 1. `v.getType() == t`
/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
///    such operand. Then `v == operand_first`.
/// 3. Otherwise, `v` is a newly created `tensor::EmptyOp` with:
///    a. Static and dynamic dims extracted from the first operand of `op`.
///    b. Element type equal to the element type of `t`.
55///
56/// This is sufficient because ElementwiseMappable guarantees that "The static
57/// types of all vector (resp. tensor) operands and results must have the same
58/// shape".
62 Location loc = op->getLoc();
63 ValueRange operands = op->getOperands();
64 TypeRange rankedTensorTypes = op->getResultTypes();
66 res.reserve(rankedTensorTypes.size());
67 for (Type t : rankedTensorTypes) {
68 // Try to find an operand with type matching the result tensor.
69 bool found = false;
70 for (Value v : operands) {
71 if (v.getType() == t) {
72 found = true;
73 res.push_back(v);
74 break;
75 }
76 }
77 if (found)
78 continue;
79
80 // Extract static / dynamic shape mix from the first operand.
81 res.push_back(tensor::EmptyOp::create(
82 b, loc, tensor::getMixedSizes(b, loc, operands.front()),
83 cast<RankedTensorType>(t).getElementType()));
84 }
85 return res;
86}
87
88namespace {
89struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
90 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
91 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
92 LogicalResult matchAndRewrite(Operation *op,
93 PatternRewriter &rewriter) const final {
95 return rewriter.notifyMatchFailure(
96 op, "requires elementwise op on ranked tensors");
97
98 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
99 auto rank = resTy.getRank();
100
101 // Maps: identity for tensors (rank > 0), scalar map for scalars.
102 AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
103 /*results=*/{}, rewriter.getContext());
104 AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
105
106 // Match phase.
107 SmallVector<bool> isScalarOperand;
108 isScalarOperand.reserve(op->getNumOperands());
109 for (Type ty : op->getOperandTypes()) {
110 if (isScalarLike(ty))
111 isScalarOperand.push_back(true);
112 else if (auto rt = dyn_cast<RankedTensorType>(ty))
113 isScalarOperand.push_back(false);
114 else
115 return rewriter.notifyMatchFailure(
116 op,
117 "unsupported operand type (expected scalar-like or ranked tensor)");
118 }
119
120 // Create indexing maps.
121 SmallVector<AffineMap> indexingMaps;
122 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
123
124 for (bool isScalar : isScalarOperand)
125 indexingMaps.push_back(isScalar ? scalarMap : idMap);
126
127 indexingMaps.append(op->getNumResults(), idMap);
128
129 SmallVector<utils::IteratorType> iteratorTypes(
130 rank, utils::IteratorType::parallel);
131 SmallVector<Value> outputs =
133 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
134 op, /*resultTensorTypes=*/op->getResultTypes(),
135 /*inputs=*/op->getOperands(),
136 /*outputs=*/outputs,
137 /*indexingMaps=*/indexingMaps,
138 /*iteratorTypes=*/iteratorTypes,
139 /*bodyBuilder=*/
140 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
141 SmallVector<Type> resultEltTys =
142 llvm::map_to_vector<6>(op->getResultTypes(), [](Type type) {
143 return cast<TensorType>(type).getElementType();
144 });
145 Operation *scalarOp =
146 builder.create(loc, op->getName().getIdentifier(),
147 regionArgs.take_front(op->getNumOperands()),
148 resultEltTys, op->getAttrs());
149 linalg::YieldOp::create(builder, loc, scalarOp->getResults());
150 });
151 return success();
152 }
153};
154} // namespace
155
158 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
159 patterns.getContext());
160}
161
162namespace {
163class ConvertElementwiseToLinalgPass
164 : public impl::ConvertElementwiseToLinalgPassBase<
165 ConvertElementwiseToLinalgPass> {
166 using impl::ConvertElementwiseToLinalgPassBase<
167 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
168
169 void runOnOperation() final {
170 auto *func = getOperation();
171 auto *context = &getContext();
172 ConversionTarget target(*context);
174
176 target.markUnknownOpDynamicallyLegal([](Operation *op) {
178 });
179
180 if (failed(applyPartialConversion(func, target, std::move(patterns))))
181 signalPassFailure();
182 }
183};
184} // namespace
return success()
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns