MLIR 23.0.0git
Specialize.cpp
Go to the documentation of this file.
1//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a method to specialize generic operations to named
10// operations. Conceptually it is the opposite of generalize.cpp.
11//
12//===----------------------------------------------------------------------===//
13
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29#define DEBUG_TYPE "linalg-specialization"
30
// Replaces `genericOp` (expected in the expansion context along with
// `rewriter`) by the named elementwise binary op NEWOP. When OPERANDS_SWAP is
// true the two dps inputs are exchanged, so the named op's operand order
// matches the order in which the generic's body consumed its block arguments
// (see areBinOpsSwapped).
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

// Replaces `genericOp` by the named elementwise unary op NEWOP, forwarding the
// single dps input and the single dps init.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46// Given a elementwise single binary linalg generic op, checks whether the
47// binary op accesses operands as swapped. e.g.
48// this differentiates between a linalg-generic body that contains:
49// ^bb0(%a: f32, %b: f32, %c : f32):
50// %0 = arith.subf %a, %b : f32
51// linalg.yield %0: f32
52// against:
53// ^bb0(%a: f32, %b: f32, %c : f32):
54// %0 = arith.subf %b, %a : f32
55// linalg.yield %0: f32
56// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
57static bool areBinOpsSwapped(GenericOp genericOp) {
58 Block *body = genericOp.getBody();
59 Operation *op = &body->front();
60 bool swapped = false;
61 if (op->getOpOperand(0).get() != body->getArgument(0)) {
62 swapped = true;
63 assert(op->getOpOperand(0).get() == body->getArgument(1) &&
64 op->getOpOperand(1).get() == body->getArgument(0) &&
65 "binary op uses just one block arg");
66 }
67 return swapped;
68}
69
70//===----------------------------------------------------------------------===//
71// Specialize linalg generic to matmul variants.
72//===----------------------------------------------------------------------===//
73/// Identifies linalg.generic that is essentially named op of the form:
74// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
75//
76// It is possible that a linalg.generic may be implementing a matmul but not
77// in a straight-forward way e.g. below is matrix multiply over some slice
78// ```
79// %0 = linalg.generic {
80// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
81// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
82// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
83// iterator_types = ["parallel", "parallel", "parallel"]}
84// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
85// outs(%C : tensor<20x20x20xf32>) {
86// ^bb0(%a: f32, %b: f32, %c : f32):
87// %mul = arith.mulf %a, %b : f32
88// %add = arith.addf %mul, %c : f32
89// linalg.yield %add : f32
90// } -> tensor<20x20x20xf32>
91// ```
92// It is not possible to represent above as named op.
93// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
94// not the same as linalg.generic above.
95namespace {
96enum class IndexMatchResult {
97 Match = 0, // identity map.
98 Transposed, // transposed map.
99 Mismatch // none of the above.
100};
101
102// Checks whether the input Affine `map` contains two consecutive dims that
103// can be interpreted as accessing a 2D matrix. It is assumed that the row
104// column dimension are adjacent axis (in this order) and start at
105// `rowDimIdx` in the input map.
106//
107// e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
108// whether the map of A is identity (match), transposed, or something
109// completely different (mis-match). Similar for B and C.
110static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
111 unsigned expectedPosOfRowDim,
112 unsigned expectedPosOfColDim) {
113 // Get the matrix multiply indices. They are past the batch indices.
114 auto exprOfRowDim = map.getResults()[rowDimIdx];
115 auto exprOfColDim = map.getResults()[rowDimIdx + 1];
116
117 // They should be pure dimension ids.
118 if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
119 exprOfColDim.getKind() != AffineExprKind::DimId)
120 return IndexMatchResult::Mismatch;
121
122 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
123 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
124
125 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
126 return IndexMatchResult::Match;
127
128 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
129 return IndexMatchResult::Transposed;
130
131 return IndexMatchResult::Mismatch;
132}
133
134// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
135// All the variants expressed as pseudo regular expression:
136// `linalg.{batch_}?matmul` have same number of ins/out, so it's easy to
137// stamp different versions.
138// `castTy` is an optional type function that indicates whether (and which) cast
139// attribute is needed for the named matmul op variant.
140template <typename NamedOpTy>
141static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
142 std::optional<TypeFn> castTy,
143 ArrayRef<AffineMap> indexingMaps) {
145 // Only explicitly specify the cast attribute for unsigned cast; signed is
146 // the default for linalg.matmul/linalg.batch_matmul.
147 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
148 auto castAttr = rewriter.getNamedAttr(
149 "cast", TypeFnAttr::get(rewriter.getContext(), *castTy));
150 attributes.push_back(castAttr);
151 }
152
153 // Set the original generic's maps to preserve operand indexing semantics like
154 // transposition.
155 SmallVector<Attribute, 3> indexingMapsAttrVal =
156 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
157 return AffineMapAttr::get(map);
158 });
159 auto indexingMapsAttr = rewriter.getNamedAttr(
160 "indexing_maps", rewriter.getArrayAttr(indexingMapsAttrVal));
161 attributes.push_back(indexingMapsAttr);
162
163 LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
164 op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
165 ValueRange{op.getDpsInits()[0]}, attributes);
166
167 return namedOp;
168}
169
170// Returns the cast type to use for a matmul-like named op. If the generic
171// contains casts that cannot be represented (e.g. output casts or mixed
172// signedness), return std::nullopt.
173static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
174 bool foundCastForMatmulOutput = false;
175 SmallVector<TypeFn> castTyFns;
176 genericOp.getBody()->walk([&](CastOpInterface castOp) {
177 // Collect forward slice of the cast op to check if it is for the matmul
178 // output.
179 SetVector<Operation *> forwardSlice;
180 getForwardSlice(castOp, &forwardSlice);
181
182 // If there is no multiplication op in the forward slice, then this cast
183 // op is for the matmul output. Cast ops on matmul output cannot be
184 // expressed by the matmul op variant.
185 if (!llvm::any_of(forwardSlice, [](Operation *op) {
186 // We check explicitly for these multiplication ops in
187 // `specializeLinalgContractions()` to infer matmul-like ops.
188 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
189 })) {
190 foundCastForMatmulOutput = true;
191 return WalkResult::interrupt();
192 }
193
194 // Determine the cast type.
195 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
196 castTyFns.push_back(TypeFn::cast_unsigned);
197 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
198 castTyFns.push_back(TypeFn::cast_signed);
199
200 return WalkResult::advance();
201 });
202
203 if (foundCastForMatmulOutput)
204 return std::nullopt;
205
206 if (!castTyFns.empty()) {
207 // If there were multiple different cast types found, then we can't express
208 // them using matmul-like ops. They only allow a single cast type for all
209 // inputs.
210 if (!llvm::all_equal(castTyFns))
211 return std::nullopt;
212 return castTyFns.front();
213 }
214
215 // Default to signed cast for matmul-like ops.
216 return TypeFn::cast_signed;
217}
218
219// Converts linalg.generic to named linalg.*matmul* where possible.
220static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
221 GenericOp genericOp,
222 bool emitCategoryOp) {
223 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
224 return failure();
225
226 // Early exit if not projected permutations.
227 auto mapRange = genericOp.getIndexingMapsArray();
228 if (llvm::any_of(mapRange,
229 [](AffineMap m) { return !m.isProjectedPermutation(); }))
230 return failure();
231
232 // Only mul+add contraction is supported.
233 // Currently, there is no way to control the contraction body type in named
234 // and category ops which all default to mul+add only.
236 *genericOp.getBlock(), [](Operation *first, Operation *second) {
237 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
238 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
239 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
240 }))
241 return failure();
242
243 // Determine the cast type for the named matmul op, or bail out if casts
244 // cannot be represented by the named op.
245 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
246 if (!castTy)
247 return rewriter.notifyMatchFailure(
248 genericOp, "contains invalid cast ops for the named matmul op");
249
250 // In case of category op, wider range of variants is supported.
251 if (emitCategoryOp)
252 return replaceWithMatmulVariant<ContractOp>(
253 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
254
255 // Further checks for named variants.
256 //
257 // Linalg generic contraction can be across multiple axis e.g.
258 // ```
259 // linalg.generic
260 // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
261 // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
262 // affine_map<(m, n, k1, k2) -> (m, n)>],
263 // iterator_types = ["parallel", "parallel",
264 // "reduction", "reduction"]}
265 // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
266 // outs(%C : tensor<10x40xf32>) {
267 // ^bb0(%a: f32, %b: f32, %c: f32):
268 // %1 = arith.mulf %a, %b : f32
269 // %2 = arith.addf %c, %1 : f32
270 // linalg.yield %2 : f32
271 // } -> tensor<10x40xf32>
272 // ```
273 // In above contraction, there are two reduction dimensions {k1, k2}
274 // and although a valid linalg contraction, it is not a named-op
275 // matrix multiply kind. Therefore, reject multi-dim reduction.
276 auto res = inferContractionDims(genericOp);
277 if (!succeeded(res))
278 return failure();
279 auto dims = *res;
280 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
281 return failure();
282
283 // Check rank of operands
284 auto indexingMaps = genericOp.getIndexingMapsArray();
285 if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
286 return m.getResults().size() !=
287 dims.batch.size() + 2 /* any two of {m,n,k} */;
288 }))
289 return failure();
290
291 auto numOfBatchDims = dims.batch.size();
292 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
293 return failure();
294
295 if (numOfBatchDims) {
296 // Each operand in a linalg generic contraction could express different
297 // permutations for its batch dimension. But for named op it must be
298 // identity since separate maps are not specified.
299 if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
300 for (unsigned i = 0; i < numOfBatchDims; ++i) {
301 auto expr = m.getResults()[i];
302 if (expr.getKind() != AffineExprKind::DimId ||
303 cast<AffineDimExpr>(expr).getPosition() != i)
304 return true;
305 }
306 return false;
307 }))
308 return failure();
309 }
310
311 auto a =
312 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
313 auto b =
314 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
315 auto c =
316 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
317
318 if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
319 return failure();
320
321 // Build indexing maps for the named op in its canonical dimension ordering
322 auto *ctx = genericOp.getContext();
323 unsigned numLoopDims = numOfBatchDims + 3;
324 unsigned mIdx = numOfBatchDims;
325 unsigned nIdx = mIdx + 1;
326 unsigned kIdx = mIdx + 2;
327
328 // TODO: add support for indexing_maps with broadcasts.
329 auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
330 SmallVector<unsigned> tensorDims;
331 for (unsigned i = 0; i < numOfBatchDims; ++i)
332 tensorDims.push_back(i);
333 if (match == IndexMatchResult::Transposed)
334 llvm::append_values(tensorDims, colIdx, rowIdx);
335 else
336 llvm::append_values(tensorDims, rowIdx, colIdx);
337 return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
338 };
339
340 auto mapA = makeMap(a, mIdx, kIdx);
341 auto mapB = makeMap(b, kIdx, nIdx);
342 auto mapC = makeMap(c, mIdx, nIdx);
343
344 SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
345
346 // Codegen the different matmul variants.
347 if (numOfBatchDims) {
348 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
349 namedOpMaps);
350 }
351 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
352 namedOpMaps);
353}
354
355/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
356/// with `dilations` and `strides`.
357template <typename ConvOpTy>
358static FailureOr<LinalgOp>
359specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
360 ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
361 SmallVector<Value> inputs = genericOp.getDpsInputs();
362 ValueRange outputs = genericOp.getDpsInits();
363 SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
364 ? TypeRange(ValueRange(outputs))
365 : TypeRange{};
366 LinalgOp namedOp;
367 // Ops with no dilations and no strides.
368 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
369 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
370 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
371 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
372 inputs, outputs);
373 } else {
374 Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
375 Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
376 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
377 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
378 }
379 return namedOp;
380}
381
382/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
383static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
384 GenericOp genericOp) {
385#define CONV_OP_SPECIALIZER(ConvOpTy) \
386 if (std::optional<DilationsAndStrides> convParams = \
387 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
388 return specializeToConvOp<ConvOpTy>( \
389 rewriter, genericOp, convParams->dilations, convParams->strides); \
390 // -----------------------------
391 // Convolution ops.
392 // -----------------------------
393 CONV_OP_SPECIALIZER(linalg::Conv1DOp);
394 CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
395 CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
396 CONV_OP_SPECIALIZER(linalg::Conv2DOp);
397 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
398 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
399 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
400 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
401 CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
402 CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
403 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
404 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
405 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
406 CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
407 CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
408 CONV_OP_SPECIALIZER(linalg::Conv3DOp);
409 CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfOp);
410 CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfQOp);
411 CONV_OP_SPECIALIZER(linalg::Conv3DNcdhwFcdhwOp);
412 // -----------------------------
413 // Depthwise Convolution ops.
414 // -----------------------------
415 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
416 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
417 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
418 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
419 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
420 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
421 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
422 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
423 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcOp);
424 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNcdhwCdhwOp);
425 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
426 // -----------------------------
427 // Pooling ops.
428 // -----------------------------
429 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
430 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
431 CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
432 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
433 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
434 CONV_OP_SPECIALIZER(linalg::PoolingNchwSumOp);
435 CONV_OP_SPECIALIZER(linalg::PoolingNchwMaxOp);
436 CONV_OP_SPECIALIZER(linalg::PoolingNwcSumOp);
437 CONV_OP_SPECIALIZER(linalg::PoolingNcwSumOp);
438 CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxOp);
439 CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxUnsignedOp);
440 CONV_OP_SPECIALIZER(linalg::PoolingNcwMaxOp);
441 CONV_OP_SPECIALIZER(linalg::PoolingNwcMinOp);
442 CONV_OP_SPECIALIZER(linalg::PoolingNwcMinUnsignedOp);
443 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcSumOp);
444 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMaxOp);
445 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMinOp);
446#undef CONV_OP_SPECIALIZER
447 return failure();
448}
449
450} // namespace
451
452//===----------------------------------------------------------------------===//
453// Categorize linalg generic to named op where possible.
454//===----------------------------------------------------------------------===//
456 RewriterBase &rewriter, GenericOp genericOp,
458 // Contraction - e.g. matmul
459 if (isaContractionOpInterface(genericOp)) {
460 return specializeLinalgContractions(rewriter, genericOp,
461 options.emitCategoryOps);
462 }
463
464 // Early exit in case of category specialization.
465 // TODO: Remove when matches for other ops account for both named and
466 // category.
467 if (options.emitCategoryOps)
468 return rewriter.notifyMatchFailure(
469 genericOp, "no matching category op specialization");
470
471 // Copy
472 if (isaCopyOpInterface(genericOp)) {
473 LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
474 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
475 return namedOp;
476 }
477
478 // Fill
479 if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
480 // Always use the detected fill value, regardless of pattern
481 LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
482 genericOp, *fillValue, genericOp.getDpsInits()[0]);
483 return namedOp;
484 }
485
486 // Broadcast
487 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
488 isaBroadcastOpInterface(genericOp);
489 if (equivalentToBroadcast) {
490 auto dims = *equivalentToBroadcast;
491 LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
492 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
493 dims);
494 return namedOp;
495 }
496
497 // Transpose
498 std::optional<SmallVector<int64_t>> equivalentToTranspose =
499 isaTransposeOpInterface(genericOp);
500 if (equivalentToTranspose) {
501 auto permutation = *equivalentToTranspose;
502 LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
503 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
504 permutation);
505 return namedOp;
506 }
507
508 // Elementwise Unary
509 if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
510 Operation *op = &genericOp.getBody()->front();
511 if (isa<math::ExpOp>(op)) {
512 LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
513 return namedOp;
514 }
515 }
516
517 // Elementwise Binary
518 if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
519 bool swap = areBinOpsSwapped(genericOp);
520 Operation *op = &genericOp.getBody()->front();
521 if (isa<arith::AddFOp>(op)) {
522 LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
523 return namedOp;
524 }
525 if (isa<arith::SubFOp>(op)) {
526 LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
527 return namedOp;
528 }
529 if (isa<arith::MulFOp>(op)) {
530 LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
531 return namedOp;
532 }
533 if (isa<arith::DivFOp>(op)) {
534 LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
535 return namedOp;
536 }
537 }
538
539 // Convolution - e.g. *conv/pooling*
540 if (isaConvolutionOpInterface(genericOp))
541 return specializeLinalgConvolutions(rewriter, genericOp);
542
543 return rewriter.notifyMatchFailure(genericOp,
544 "no matching named op specialization");
545}
546
547namespace {
548struct LinalgSpecializeGenericOpsPass
550 LinalgSpecializeGenericOpsPass> {
551
553 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
554 void runOnOperation() override;
555};
556} // namespace
557
558void LinalgSpecializeGenericOpsPass::runOnOperation() {
559 RewritePatternSet patterns(&getContext());
562
563 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
564 signalPassFailure();
565}
566
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
#define REPLACE_UNARY_OP(NEWOP)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:190
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpOperand & getOpOperand(unsigned idx)
Definition Operation.h:417
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.