MLIR 23.0.0git
Specialize.cpp
Go to the documentation of this file.
1//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a method to specialize generic operations to named
10// operations. Conceptually it is the opposite of generalize.cpp.
11//
12//===----------------------------------------------------------------------===//
13
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29#define DEBUG_TYPE "linalg-specialization"
30
// Replaces `genericOp` (captured from the expansion site, together with
// `rewriter`) by the named binary op NEWOP, feeding the two DPS inputs in
// original or swapped order depending on OPERANDS_SWAP (see
// areBinOpsSwapped()). Expands to the newly created op.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))
37
38using namespace mlir;
39using namespace mlir::linalg;
40
41//===----------------------------------------------------------------------===//
42// Specialize linalg generic to elementwise ops.
43//===----------------------------------------------------------------------===//
44
45// Given a elementwise single binary linalg generic op, checks whether the
46// binary op accesses operands as swapped. e.g.
47// this differentiates between a linalg-generic body that contains:
48// ^bb0(%a: f32, %b: f32, %c : f32):
49// %0 = arith.subf %a, %b : f32
50// linalg.yield %0: f32
51// against:
52// ^bb0(%a: f32, %b: f32, %c : f32):
53// %0 = arith.subf %b, %a : f32
54// linalg.yield %0: f32
55// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
56static bool areBinOpsSwapped(GenericOp genericOp) {
57 Block *body = genericOp.getBody();
58 Operation *op = &body->front();
59 bool swapped = false;
60 if (op->getOpOperand(0).get() != body->getArgument(0)) {
61 swapped = true;
62 assert(op->getOpOperand(0).get() == body->getArgument(1) &&
63 op->getOpOperand(1).get() == body->getArgument(0) &&
64 "binary op uses just one block arg");
65 }
66 return swapped;
67}
68
69// Attempt to specialize linalg.generic to named elementwise ops or
70// linalg.elementwise.
71//
72// Example:
73// %0 = linalg.generic {
74// indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
75// affine_map<(d0, d1) -> (d0, d1)>],
76// iterator_types = ["parallel", "parallel"]
77// } ins(%In : tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
78// ^bb0(%in: f32, %out: f32):
79// %1 = math.exp %in : f32
80// linalg.yield %1 : f32
81// } -> tensor<?x?xf32>
82//
83// is specialized to either
84// linalg.exp ins(...) outs(...) -> ...
85// or
86// linalg.elementwise kind=#linalg.elementwise_kind<exp> ...
87//
88// Only the category op can carry non-identity indexing maps; these are
89// transferred verbatim from the `genericOp`.
// Specializes `genericOp` to a named unary elementwise op (e.g. linalg.exp)
// or, when `emitCategoryOp` is set, to linalg.elementwise with the matching
// ElementwiseKind. Returns failure if the body op is not a recognized unary
// pattern. Only the category op may carry non-identity indexing maps.
static FailureOr<LinalgOp>
specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
                                 bool emitCategoryOp) {
  bool hasNonIdentityMaps =
      !llvm::all_of(genericOp.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); });

  // Early exit: Named ops cannot carry user-defined maps.
  if (hasNonIdentityMaps && !emitCategoryOp)
    return rewriter.notifyMatchFailure(
        genericOp,
        "non-identity indexing maps prevent specialization to named op");

  // Helper to dispatch between named op and `linalg.elementwise`.
  // Lambdas with explicit template parameter list are a C++20 feature, hence
  // the dummy op object: `decltype(namedOp)` recovers the named op type.
  auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
    LinalgOp newOp;
    if (!emitCategoryOp)
      newOp = decltype(namedOp)::create(
          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
          genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
    else
      // Category op: transfer the generic's indexing maps verbatim.
      newOp = ElementwiseOp::create(
          rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
          genericOp.getDpsInits(),
          ElementwiseKindAttr::get(rewriter.getContext(), kind),
          genericOp.getIndexingMaps());

    rewriter.replaceOp(genericOp, newOp);
    return newOp;
  };

  // Inspect body operation to determine named op or elementwise kind.
  Operation *op = &genericOp.getBody()->front();

  if (isa<math::ExpOp>(op))
    return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
  if (isa<math::LogOp>(op))
    return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
  if (isa<math::AbsFOp>(op))
    return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
  if (isa<math::CeilOp>(op))
    return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
  if (isa<math::FloorOp>(op))
    return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
  if (isa<arith::NegFOp>(op))
    return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
  // `1.0 / x` is recognized as a reciprocal, provided the numerator is a
  // constant that is exactly 1.0.
  if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
    if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
            divOp.getLhs().getDefiningOp()))
      if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
        return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
  }
  if (isa<math::RoundOp>(op))
    return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
  if (isa<math::SqrtOp>(op))
    return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
  if (isa<math::RsqrtOp>(op))
    return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
  // `x * x` with the same SSA value on both sides is a square.
  if (auto mulOp = dyn_cast<arith::MulFOp>(op);
      mulOp && mulOp.getLhs() == mulOp.getRhs())
    return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
  if (isa<math::TanhOp>(op))
    return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
  if (isa<math::ErfOp>(op))
    return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);

  return failure();
}
160
161//===----------------------------------------------------------------------===//
162// Specialize linalg generic to matmul variants.
163//===----------------------------------------------------------------------===//
164/// Identifies linalg.generic that is essentially named op of the form:
165// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
166//
167// It is possible that a linalg.generic may be implementing a matmul but not
168// in a straight-forward way e.g. below is matrix multiply over some slice
169// ```
170// %0 = linalg.generic {
171// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
172// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
173// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
174// iterator_types = ["parallel", "parallel", "parallel"]}
175// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
176// outs(%C : tensor<20x20x20xf32>) {
177// ^bb0(%a: f32, %b: f32, %c : f32):
178// %mul = arith.mulf %a, %b : f32
179// %add = arith.addf %mul, %c : f32
180// linalg.yield %add : f32
181// } -> tensor<20x20x20xf32>
182// ```
183// It is not possible to represent above as named op.
184// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
185// not the same as linalg.generic above.
186namespace {
187enum class IndexMatchResult {
188 Match = 0, // identity map.
189 Transposed, // transposed map.
190 Mismatch // none of the above.
191};
192
193// Checks whether the input Affine `map` contains two consecutive dims that
194// can be interpreted as accessing a 2D matrix. It is assumed that the row
195// column dimension are adjacent axis (in this order) and start at
196// `rowDimIdx` in the input map.
197//
198// e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
199// whether the map of A is identity (match), transposed, or something
200// completely different (mis-match). Similar for B and C.
201static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
202 unsigned expectedPosOfRowDim,
203 unsigned expectedPosOfColDim) {
204 // Get the matrix multiply indices. They are past the batch indices.
205 auto exprOfRowDim = map.getResults()[rowDimIdx];
206 auto exprOfColDim = map.getResults()[rowDimIdx + 1];
207
208 // They should be pure dimension ids.
209 if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
210 exprOfColDim.getKind() != AffineExprKind::DimId)
211 return IndexMatchResult::Mismatch;
212
213 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
214 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
215
216 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
217 return IndexMatchResult::Match;
218
219 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
220 return IndexMatchResult::Transposed;
221
222 return IndexMatchResult::Mismatch;
223}
224
225// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
226// All the variants expressed as pseudo regular expression:
227// `linalg.{batch_}?matmul` have same number of ins/out, so it's easy to
228// stamp different versions.
229// `castTy` is an optional type function that indicates whether (and which) cast
230// attribute is needed for the named matmul op variant.
231template <typename NamedOpTy>
232static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
233 std::optional<TypeFn> castTy,
234 ArrayRef<AffineMap> indexingMaps) {
236 // Only explicitly specify the cast attribute for unsigned cast; signed is
237 // the default for linalg.matmul/linalg.batch_matmul.
238 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
239 auto castAttr = rewriter.getNamedAttr(
240 "cast", TypeFnAttr::get(rewriter.getContext(), *castTy));
241 attributes.push_back(castAttr);
242 }
243
244 // Set the original generic's maps to preserve operand indexing semantics like
245 // transposition.
246 SmallVector<Attribute, 3> indexingMapsAttrVal =
247 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
248 return AffineMapAttr::get(map);
249 });
250 auto indexingMapsAttr = rewriter.getNamedAttr(
251 "indexing_maps", rewriter.getArrayAttr(indexingMapsAttrVal));
252 attributes.push_back(indexingMapsAttr);
253
254 LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
255 op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
256 ValueRange{op.getDpsInits()[0]}, attributes);
257
258 return namedOp;
259}
260
261// Returns the cast type to use for a matmul-like named op. If the generic
262// contains casts that cannot be represented (e.g. output casts or mixed
263// signedness), return std::nullopt.
// Returns the cast type to use for a matmul-like named op. If the generic
// contains casts that cannot be represented (e.g. output casts or mixed
// signedness), return std::nullopt. When no casts are present, defaults to
// signed cast, which is the matmul variants' implicit behavior.
static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
  bool foundCastForMatmulOutput = false;
  SmallVector<TypeFn> castTyFns;
  genericOp.getBody()->walk([&](CastOpInterface castOp) {
    // Collect forward slice of the cast op to check if it is for the matmul
    // output.
    SetVector<Operation *> forwardSlice;
    getForwardSlice(castOp, &forwardSlice);

    // If there is no multiplication op in the forward slice, then this cast
    // op is for the matmul output. Cast ops on matmul output cannot be
    // expressed by the matmul op variant.
    if (!llvm::any_of(forwardSlice, [](Operation *op) {
          // We check explicitly for these multiplication ops in
          // `specializeLinalgContractions()` to infer matmul-like ops.
          return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
        })) {
      foundCastForMatmulOutput = true;
      // Abort the walk: one unrepresentable output cast is disqualifying.
      return WalkResult::interrupt();
    }

    // Determine the cast type. Casts of other kinds (e.g. truncations) are
    // deliberately not recorded here.
    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
      castTyFns.push_back(TypeFn::cast_unsigned);
    else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
      castTyFns.push_back(TypeFn::cast_signed);

    return WalkResult::advance();
  });

  if (foundCastForMatmulOutput)
    return std::nullopt;

  if (!castTyFns.empty()) {
    // If there were multiple different cast types found, then we can't express
    // them using matmul-like ops. They only allow a single cast type for all
    // inputs.
    if (!llvm::all_equal(castTyFns))
      return std::nullopt;
    return castTyFns.front();
  }

  // Default to signed cast for matmul-like ops.
  return TypeFn::cast_signed;
}
309
310static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
311 GenericOp genericOp,
312 std::optional<TypeFn> castTy,
313 ContractionDimensions &dims) {
314 // Should all be rank 4 and dim 6
315 auto indexingMaps = genericOp.getIndexingMapsArray();
316 if (llvm::any_of(indexingMaps, [](AffineMap m) {
317 return m.getResults().size() != 4 || m.getNumDims() != 6;
318 }))
319 return failure();
320
321 auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
322 auto aInner = matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
323
324 auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
325 auto bInner = matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
326
327 auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
328 auto cInner = matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
329
330 if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
331 return failure();
332 if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
333 return failure();
334
335 SmallVector<AffineMap> namedOpMaps = {indexingMaps[0], indexingMaps[1],
336 indexingMaps[2]};
337
338 return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
339 namedOpMaps);
340}
341
342// Converts linalg.generic to named linalg.*matmul* where possible.
343static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
344 GenericOp genericOp,
345 bool emitCategoryOp) {
346 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
347 return failure();
348
349 // Early exit if not projected permutations.
350 auto mapRange = genericOp.getIndexingMapsArray();
351 if (llvm::any_of(mapRange,
352 [](AffineMap m) { return !m.isProjectedPermutation(); }))
353 return failure();
354
355 // Only mul+add contraction is supported.
356 // Currently, there is no way to control the contraction body type in named
357 // and category ops which all default to mul+add only.
359 *genericOp.getBlock(), [](Operation *first, Operation *second) {
360 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
361 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
362 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
363 }))
364 return failure();
365
366 // Determine the cast type for the named matmul op, or bail out if casts
367 // cannot be represented by the named op.
368 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
369 if (!castTy)
370 return rewriter.notifyMatchFailure(
371 genericOp, "contains invalid cast ops for the named matmul op");
372
373 // In case of category op, wider range of variants is supported.
374 if (emitCategoryOp)
375 return replaceWithMatmulVariant<ContractOp>(
376 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
377
378 // Further checks for named variants.
379 //
380 // Linalg generic contraction can be across multiple axis e.g.
381 // ```
382 // linalg.generic
383 // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
384 // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
385 // affine_map<(m, n, k1, k2) -> (m, n)>],
386 // iterator_types = ["parallel", "parallel",
387 // "reduction", "reduction"]}
388 // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
389 // outs(%C : tensor<10x40xf32>) {
390 // ^bb0(%a: f32, %b: f32, %c: f32):
391 // %1 = arith.mulf %a, %b : f32
392 // %2 = arith.addf %c, %1 : f32
393 // linalg.yield %2 : f32
394 // } -> tensor<10x40xf32>
395 // ```
396 // In above contraction, there are two reduction dimensions {k1, k2}
397 // and although a valid linalg contraction, it is not a named-op
398 // matrix multiply kind. Therefore, reject multi-dim reduction.
399 auto res = inferContractionDims(genericOp);
400 if (!succeeded(res))
401 return failure();
402 auto dims = *res;
403 if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
404 return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
405 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
406 return failure();
407
408 // Check rank of operands
409 auto indexingMaps = genericOp.getIndexingMapsArray();
410 if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
411 return m.getResults().size() !=
412 dims.batch.size() + 2 /* any two of {m,n,k} */;
413 }))
414 return failure();
415
416 auto numOfBatchDims = dims.batch.size();
417 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
418 return failure();
419
420 if (numOfBatchDims) {
421 // Each operand in a linalg generic contraction could express different
422 // permutations for its batch dimension. But for named op it must be
423 // identity since separate maps are not specified.
424 if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
425 for (unsigned i = 0; i < numOfBatchDims; ++i) {
426 auto expr = m.getResults()[i];
427 if (expr.getKind() != AffineExprKind::DimId ||
428 cast<AffineDimExpr>(expr).getPosition() != i)
429 return true;
430 }
431 return false;
432 }))
433 return failure();
434 }
435
436 auto a =
437 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
438 auto b =
439 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
440 auto c =
441 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
442
443 if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
444 return failure();
445
446 // Build indexing maps for the named op in its canonical dimension ordering
447 auto *ctx = genericOp.getContext();
448 unsigned numLoopDims = numOfBatchDims + 3;
449 unsigned mIdx = numOfBatchDims;
450 unsigned nIdx = mIdx + 1;
451 unsigned kIdx = mIdx + 2;
452
453 // TODO: add support for indexing_maps with broadcasts.
454 auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
455 SmallVector<unsigned> tensorDims;
456 for (unsigned i = 0; i < numOfBatchDims; ++i)
457 tensorDims.push_back(i);
458 if (match == IndexMatchResult::Transposed)
459 llvm::append_values(tensorDims, colIdx, rowIdx);
460 else
461 llvm::append_values(tensorDims, rowIdx, colIdx);
462 return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
463 };
464
465 auto mapA = makeMap(a, mIdx, kIdx);
466 auto mapB = makeMap(b, kIdx, nIdx);
467 auto mapC = makeMap(c, mIdx, nIdx);
468
469 SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
470
471 // Codegen the different matmul variants.
472 if (numOfBatchDims) {
473 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
474 namedOpMaps);
475 }
476 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
477 namedOpMaps);
478}
479
480/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
481/// with `dilations` and `strides`.
482template <typename ConvOpTy>
483static FailureOr<LinalgOp>
484specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
485 ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
486 SmallVector<Value> inputs = genericOp.getDpsInputs();
487 ValueRange outputs = genericOp.getDpsInits();
488 SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
489 ? TypeRange(ValueRange(outputs))
490 : TypeRange{};
491 LinalgOp namedOp;
492 // Ops with no dilations and no strides.
493 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
494 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
495 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
496 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
497 inputs, outputs);
498 } else {
499 Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
500 Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
501 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
502 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
503 }
504 return namedOp;
505}
506
507/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
/// Tries each supported conv/depthwise-conv/pooling variant in turn; the
/// first one whose structural match succeeds wins.
// NOTE(review): the trailing `\` after the macro's last code line absorbs the
// following section comment into the #define (harmless, since comments are
// stripped before expansion) — confirm against the original file whether this
// is an extraction artifact.
static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
// Expands to: if `genericOp` matches ConvOpTy, rewrite it to that op with the
// matched dilations and strides and return.
#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
  if (std::optional<DilationsAndStrides> convParams =                          \
          matchConvolutionOpOfType<ConvOpTy>(genericOp))                       \
    return specializeToConvOp<ConvOpTy>(                                       \
        rewriter, genericOp, convParams->dilations, convParams->strides);      \
  // -----------------------------
  // Convolution ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::Conv1DOp);
  CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
  CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
  CONV_OP_SPECIALIZER(linalg::Conv3DOp);
  CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfOp);
  CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfQOp);
  CONV_OP_SPECIALIZER(linalg::Conv3DNcdhwFcdhwOp);
  // -----------------------------
  // Depthwise Convolution ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNcdhwCdhwOp);
  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
  // -----------------------------
  // Pooling ops.
  // -----------------------------
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNchwSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNchwMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNwcSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNcwSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxUnsignedOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNcwMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNwcMinUnsignedOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcSumOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMaxOp);
  CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMinOp);
#undef CONV_OP_SPECIALIZER
  return failure();
}
574
575} // namespace
576
577//===----------------------------------------------------------------------===//
578// Categorize linalg generic to named op where possible.
579//===----------------------------------------------------------------------===//
581 RewriterBase &rewriter, GenericOp genericOp,
583 // Unary elementwise - e.g. exp
584 if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
585 return specializeLinalgUnaryElementwise(rewriter, genericOp,
586 options.emitCategoryOps);
587 }
588
589 // Contraction - e.g. matmul
590 if (isaContractionOpInterface(genericOp)) {
591 return specializeLinalgContractions(rewriter, genericOp,
592 options.emitCategoryOps);
593 }
594
595 // Early exit in case of category specialization.
596 // TODO: Remove when matches for other ops account for both named and
597 // category.
598 if (options.emitCategoryOps)
599 return rewriter.notifyMatchFailure(
600 genericOp, "no matching category op specialization");
601
602 // Copy
603 if (isaCopyOpInterface(genericOp)) {
604 LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
605 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
606 return namedOp;
607 }
608
609 // Fill
610 if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
611 // Always use the detected fill value, regardless of pattern
612 LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
613 genericOp, *fillValue, genericOp.getDpsInits()[0]);
614 return namedOp;
615 }
616
617 // Broadcast
618 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
619 isaBroadcastOpInterface(genericOp);
620 if (equivalentToBroadcast) {
621 auto dims = *equivalentToBroadcast;
622 LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
623 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
624 dims);
625 return namedOp;
626 }
627
628 // Transpose
629 std::optional<SmallVector<int64_t>> equivalentToTranspose =
630 isaTransposeOpInterface(genericOp);
631 if (equivalentToTranspose) {
632 auto permutation = *equivalentToTranspose;
633 LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
634 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
635 permutation);
636 return namedOp;
637 }
638
639 // Elementwise Binary
640 if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
641 bool swap = areBinOpsSwapped(genericOp);
642 Operation *op = &genericOp.getBody()->front();
643 if (isa<arith::AddFOp>(op)) {
644 LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
645 return namedOp;
646 }
647 if (isa<arith::SubFOp>(op)) {
648 LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
649 return namedOp;
650 }
651 if (isa<arith::MulFOp>(op)) {
652 LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
653 return namedOp;
654 }
655 if (isa<arith::DivFOp>(op)) {
656 LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
657 return namedOp;
658 }
659 }
660
661 // Convolution - e.g. *conv/pooling*
662 if (isaConvolutionOpInterface(genericOp))
663 return specializeLinalgConvolutions(rewriter, genericOp);
664
665 return rewriter.notifyMatchFailure(genericOp,
666 "no matching named op specialization");
667}
668
669namespace {
// Pass driver for the linalg-specialize-generic-ops pass: applies the
// generic-to-named/category specialization patterns greedily over the
// operation. Options are inherited from the tablegen-generated base.
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
678} // namespace
679
// Greedily applies the specialization patterns over the pass's root op.
void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  // NOTE(review): the call(s) populating `patterns` (presumably
  // populateLinalgGenericOpsSpecializationPatterns, orig lines 682-683) are
  // missing from this extraction — confirm against the original file.

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
688
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
static FailureOr< LinalgOp > specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:190
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpOperand & getOpOperand(unsigned idx)
Definition Operation.h:414
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch