MLIR 23.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "llvm/ADT/SmallVectorExtras.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
19#include "mlir/Dialect/Linalg/Passes.h.inc"
20} // namespace mlir
21
22using namespace mlir;
23
24static inline bool isScalarLike(Type t) {
25 return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
26}
27
// Returns true when `op` should be handled by the elementwise-to-linalg
// rewrite: it must have at least one ranked tensor operand, and every operand
// must be either a ranked tensor or scalar-like (see isScalarLike). Only the
// operand types are inspected here; result types are not checked.
// NOTE(review): the function header and the condition guarding `return false;`
// (listing lines 28-29) and the closing brace (line 44) are absent from this
// rendering. Per the symbol index the signature is presumably
// `static bool isElementwiseMappableOpOnRankedTensors(Operation *op)`, and the
// guard is presumably a `hasElementwiseMappableTraits(op)` check - confirm
// against the real source.
30 return false;
31
32 auto types = op->getOperandTypes();
33
34 // We want at least one ranked tensor.
35 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
36
37 // No invalid operands (i.e., every operand is a ranked tensor or
38 // scalar-like).
39 bool noneInvalid = llvm::none_of(types, [](Type t) {
40 return !(isa<RankedTensorType>(t) || isScalarLike(t));
41 });
// Both conditions must hold; a mix of scalar-like and ranked tensor operands
// is accepted.
43 return anyRankedTensor && noneInvalid;
46/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
47/// the result types and return a list of values such that, for each result type
48/// `t` and value `v` at the same index `idx`:
49/// 1. `v.getType() == t`
50/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
51/// such operand. Then`v == operand_first`.
52/// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
53/// a. Static and dynamic dims extracted from the first operand of `op`.
54/// b. Elemental type equal to the elemental type of `t`.
55///
56/// This is sufficient because ElementwiseMappable guarantees that "The static
57/// types of all vector (resp. tensor) operands and results must have the same
58/// shape".
62 Location loc = op->getLoc();
63 ValueRange operands = op->getOperands();
64 TypeRange rankedTensorTypes = op->getResultTypes();
66 res.reserve(rankedTensorTypes.size());
67 for (Type t : rankedTensorTypes) {
68 // Try to find an operand with type matching the result tensor.
69 bool found = false;
70 for (Value v : operands) {
71 if (v.getType() == t) {
72 found = true;
73 res.push_back(v);
74 break;
75 }
76 }
77 if (found)
78 continue;
79
80 // Extract static / dynamic shape mix from the first operand.
81 res.push_back(tensor::EmptyOp::create(
82 b, loc, tensor::getMixedSizes(b, loc, operands.front()),
83 cast<RankedTensorType>(t).getElementType()));
84 }
85 return res;
87
// Rewrite pattern that matches any operation (MatchAnyOpTypeTag) and replaces
// it with a linalg.generic. All original operands become inputs, and the
// region body re-creates the original op by name (with its attributes) on
// element-typed values.
// NOTE(review): listing lines 94, 129 and 132 are absent from this rendering.
// Presumably 94 is the `if (!isElementwiseMappableOpOnRankedTensors(op))` guard
// for the first notifyMatchFailure, 129 declares `iteratorTypes`, and 132
// initializes `outputs` (likely via getOrCreateOperandsMatchingResultTypes).
// Confirm against the real source.
88namespace {
89struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
90 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
91 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
92 LogicalResult matchAndRewrite(Operation *op,
93 PatternRewriter &rewriter) const final {
95 return rewriter.notifyMatchFailure(
96 op, "requires elementwise op on ranked tensors");
97
// The iteration space takes its rank from the first result's ranked tensor
// type.
98 auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
99 auto rank = resTy.getRank();
100
// scalarMap has `rank` dims but no results, so a scalar operand is accessed
// without indices at every iteration point. idMap indexes tensors
// element-by-element.
101 // Maps: identity for tensors (rank > 0), scalar map for scalars.
102 AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
103 /*results=*/{}, rewriter.getContext());
104 AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
105
// Classify each operand as scalar-like or ranked tensor; anything else
// fails the match.
// NOTE(review): `rt` below is bound but never used; `isa<RankedTensorType>(ty)`
// would express the same check without an unused variable.
106 // Match phase.
107 SmallVector<bool> isScalarOperand;
108 isScalarOperand.reserve(op->getNumOperands());
109 for (Type ty : op->getOperandTypes()) {
110 if (isScalarLike(ty))
111 isScalarOperand.push_back(true);
112 else if (auto rt = dyn_cast<RankedTensorType>(ty))
113 isScalarOperand.push_back(false);
114 else
115 return rewriter.notifyMatchFailure(
116 op,
117 "unsupported operand type (expected scalar-like or ranked tensor)");
118 }
119
// Indexing maps are ordered: one per operand (inputs), then one identity map
// per result (outputs).
120 // Create indexing maps.
121 SmallVector<AffineMap> indexingMaps;
122 indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
123
124 for (bool isScalar : isScalarOperand)
125 indexingMaps.push_back(isScalar ? scalarMap : idMap);
126
127 indexingMaps.append(op->getNumResults(), idMap);
128
// Every loop dimension is parallel.
130 rank, utils::IteratorType::parallel);
131 SmallVector<Value> outputs =
133 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
134 op, /*resultTensorTypes=*/op->getResultTypes(),
135 /*inputs=*/op->getOperands(),
136 /*outputs=*/outputs,
137 /*indexingMaps=*/indexingMaps,
138 /*iteratorTypes=*/iteratorTypes,
139 /*bodyBuilder=*/
140 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
// The scalar op produces the element types of the original tensor results.
141 SmallVector<Type> resultEltTys =
142 llvm::map_to_vector<6>(op->getResultTypes(), [](Type type) {
143 return cast<TensorType>(type).getElementType();
144 });
// Only the first getNumOperands() block arguments feed the scalar op
// (presumably the input arguments; the output arguments are ignored).
145 Operation *scalarOp =
146 builder.create(loc, op->getName().getIdentifier(),
147 regionArgs.take_front(op->getNumOperands()),
148 resultEltTys, op->getAttrs());
149 linalg::YieldOp::create(builder, loc, scalarOp->getResults());
150 });
151 return success();
152 }
153};
154} // namespace
155
// Registers the elementwise-to-linalg rewrite pattern into `patterns`.
// NOTE(review): the signature lines (listing lines 156-157) are absent from
// this rendering. Per the symbol index this is presumably
// `void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)`
// - confirm.
158 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
159 patterns.getContext());
160}
161
// Pass driver: runs a partial dialect conversion over the pass's operation
// and signals pass failure if the conversion fails.
// NOTE(review): listing lines 164, 166, 173, 175 and 177 are absent from this
// rendering. They presumably hold the generated pass base class, the
// declaration/population of `patterns`, and the body of the legality callback.
// Confirm against the real source.
162namespace {
163class ConvertElementwiseToLinalgPass
165 ConvertElementwiseToLinalgPass> {
167 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
168
169 void runOnOperation() final {
170 auto *func = getOperation();
171 auto *context = &getContext();
172 ConversionTarget target(*context);
174
// Legality of ops not otherwise registered with the target is decided by this
// callback (its body is on a line missing from this listing).
176 target.markUnknownOpDynamicallyLegal([](Operation *op) {
178 });
179
180 if (failed(applyPartialConversion(func, target, std::move(patterns))))
181 signalPassFailure();
182 }
183};
184} // namespace
return success()
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values, one per result type, each having that result type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns