//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually, it is the opposite of Generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands swapped, e.g. this differentiates between
// a linalg.generic body that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %a, %b : f32
//     linalg.yield %0: f32
// against:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %b, %a : f32
//     linalg.yield %0: f32
// The former is linalg.sub(a,b), the latter is linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}
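
// For illustration, a sketch of the rewrite outcome (assumed tensor types):
// when the swap is detected, the binary specialization below passes the dps
// inputs in reversed order, producing e.g.
//   linalg.sub ins(%b, %a : tensor<?xf32>, tensor<?xf32>)
//              outs(%c : tensor<?xf32>) -> tensor<?xf32>
// rather than `linalg.sub ins(%a, %b ...)`, mirroring `arith.subf %b, %a`.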

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
// Identifies a linalg.generic that is essentially a named op of the form:
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//
// A linalg.generic may implement a matmul but not in a straightforward way,
// e.g. below is a matrix multiply over some slice:
// ```
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//          iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//          outs(%C : tensor<20x20x20xf32>) {
//            ^bb0(%a: f32, %b: f32, %c : f32):
//              %mul = arith.mulf %a, %b : f32
//              %add = arith.addf %mul, %c : f32
//              linalg.yield %add : f32
//          } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op;
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
// whether the map of A is identity (match), transposed, or something
// completely different (mismatch). Similarly for B and C.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
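
// Worked example (illustrative, assuming a plain matmul with m = d0, n = d1,
// k = d2 and no batch dims), checking operand A against the expected (m, k):
//   A's map (d0, d1, d2) -> (d0, d2)  => Match      (identity access)
//   A's map (d0, d1, d2) -> (d2, d0)  => Transposed (A is transposed)
//   A's map (d0, d1, d2) -> (d0, d0)  => Mismatch   (not a plain 2D access)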

// Replaces genericOp with a `NamedOpTy` op, supplied as a template arg.
// All the variants expressed by the pseudo regular expression
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the different
// versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg.generic contraction can be across multiple axes, e.g.
  // ```
  //   linalg.generic
  //       {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                         affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                         affine_map<(m, n, k1, k2) -> (m, n)>],
  //        iterator_types = ["parallel", "parallel",
  //                          "reduction", "reduction"]}
  //       ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //       outs(%C : tensor<10x40xf32>) {
  //         ^bb0(%a: f32, %b: f32, %c: f32):
  //           %1 = arith.mulf %a, %b : f32
  //           %2 = arith.addf %c, %1 : f32
  //           linalg.yield %2 : f32
  //       } -> tensor<10x40xf32>
  // ```
  // The above contraction has two reduction dimensions {k1, k2}; although it
  // is a valid linalg contraction, it is not a named-op matrix-multiply kind.
  // Therefore, reject multi-dimensional reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  // Check that the scalar body is a mul/add (or complex mul/add) contraction.
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation for its batch dimensions. But for named ops it must be the
    // identity, since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }
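
  // For example (illustrative): with a single batch dimension, a map such as
  //   (d0, d1, d2, d3) -> (d0, d1, d3)
  // keeps the batch dim d0 leading and passes the check above, whereas
  //   (d0, d1, d2, d3) -> (d1, d0, d3)
  // permutes the batch dim and is rejected.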

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
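
// For reference, a sketch (assumed shapes, illustrative only) of a generic
// that the routine above does rewrite:
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"]}
//          ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
//          outs(%C : tensor<16x32xf32>) { mulf/addf body }
//        -> tensor<16x32xf32>
// becomes `linalg.matmul ins(%A, %B ...) outs(%C ...)`; with a leading batch
// dimension in every map it becomes `linalg.batch_matmul` instead.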

/// Utility to specialize a `genericOp` into a convolution op of type
/// `ConvOpTy` with the given `dilations` and `strides`.
template <typename ConvOpTy>
static FailureOr<LinalgOp>
specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
  SmallVector<Value> inputs = genericOp.getDpsInputs();
  ValueRange outputs = genericOp.getDpsInits();
  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
                                      ? TypeRange(ValueRange(outputs))
                                      : TypeRange{};
  LinalgOp namedOp;
  // Ops with no dilations and no strides.
  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
                                                    inputs, outputs);
  } else {
    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
  }
  return namedOp;
}
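
// Note (illustrative): getI64TensorAttr produces a rank-1 i64 tensor
// attribute, so e.g. unit strides/dilations for a 2-D variant appear as
//   strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>
// on the resulting named op, while the plain Conv1D/2D/3D ops above carry no
// such attributes.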

/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  SmallVector<int64_t> dilations, strides;
#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
  if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))       \
    return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
                                        strides);
  // -----------------------------
  // Convolution ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::Conv1DOp);
  CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
  CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
  CONV_OP_SPECIALIZER(linalg::Conv3DOp);
  // -----------------------------
  // Depthwise Convolution ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
  // -----------------------------
  // Pooling ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
#undef CONV_OP_SPECIALIZER
  return failure();
}
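
// Illustrative outcome (sketch, assumed types): a generic whose maps and body
// match a 2-D NHWC/HWCF convolution is rewritten by the above into
//   linalg.conv_2d_nhwc_hwcf
//       {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
//       ins(%input, %filter : tensor<1x8x8x3xf32>, tensor<3x3x3x4xf32>)
//       outs(%out : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>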

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
    // Always use the detected fill value, regardless of pattern.
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, *fillValue, genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }

  // Convolution - e.g. *conv/pooling*
  if (isaConvolutionOpInterface(genericOp)) {
    return specializeLinalgConvolutions(rewriter, genericOp);
  }
  return failure();
}
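
// Minimal usage sketch (hypothetical caller, not part of this file): the
// entry point above composes with the rewrite infrastructure as a pattern,
// much like the pass below, e.g.
//
//   struct SpecializeGenericPattern : public OpRewritePattern<GenericOp> {
//     using OpRewritePattern<GenericOp>::OpRewritePattern;
//     LogicalResult matchAndRewrite(GenericOp op,
//                                   PatternRewriter &rewriter) const override {
//       if (failed(specializeGenericOp(rewriter, op)))
//         return failure();
//       return success();
//     }
//   };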

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

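// Usage note: the pass defined above is typically exercised end-to-end as,
// e.g., `mlir-opt --linalg-specialize-generic-ops input.mlir` (flag name
// assumed from the generated pass definition), greedily converting eligible
// linalg.generic ops back to their named forms.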