MLIR 23.0.0git
Specialize.cpp
Go to the documentation of this file.
1//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a method to specialize generic operations to named
10// operations. Conceptually it is the opposite of generalize.cpp.
11//
12//===----------------------------------------------------------------------===//
13
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29#define DEBUG_TYPE "linalg-specialization"
30
31using namespace mlir;
32using namespace mlir::linalg;
33
34//===----------------------------------------------------------------------===//
35// Specialize linalg generic to elementwise ops.
36//===----------------------------------------------------------------------===//
37
38// Given an elementwise single binary linalg generic op, checks whether the
39// binary op accesses operands as swapped. e.g.
40// this differentiates between a linalg-generic body that contains:
41// ^bb0(%a: f32, %b: f32, %c : f32):
42// %0 = arith.subf %a, %b : f32
43// linalg.yield %0: f32
44// against:
45// ^bb0(%a: f32, %b: f32, %c : f32):
46// %0 = arith.subf %b, %a : f32
47// linalg.yield %0: f32
48// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
49static bool areBinOpsSwapped(GenericOp genericOp) {
50 Block *body = genericOp.getBody();
51 Operation *op = &body->front();
52 bool swapped = false;
53 if (op->getOpOperand(0).get() != body->getArgument(0)) {
54 swapped = true;
55 assert(op->getOpOperand(0).get() == body->getArgument(1) &&
56 op->getOpOperand(1).get() == body->getArgument(0) &&
57 "binary op uses just one block arg");
58 }
59 return swapped;
60}
61
62// Given an elementwise single unary linalg generic op whose body operation is a
63// binary operation, check if one of its operands is a scalar value defined
64// outside the generic op, set its index, and return true. Otherwise return
65// false. The index is unique because the block argument is used at
66// least by one operand, as checked in `isaElemwiseSingleUnaryOpInterface`.
67//
68// Example:
69// %cst = arith.constant 3.14 : f32
70// %0 = linalg.generic { indexing_maps = [#mapA, #mapRes], ... }
71// ins(%A : tensor<?xf32>) outs(...) {
72// ^bb0(%a: f32, %out : f32):
73// %0 = arith.mulf %a, %cst : f32
74// linalg.yield %0: f32
75// } -> tensor<?xf32>
76// Here, the returned index is 1, and the generic op can be represented as
77// %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
78// indexing_maps = [#mapA, affine_map<(d0) -> ()>, #mapRes]
79// ins(%A, %cst : tensor<?xf32>, f32) outs(...) -> tensor<?xf32>
80static bool findIndexOfScalarOperand(GenericOp genericOp, int &index) {
81 Block *body = genericOp.getBody();
82 Operation *op = &body->front();
83 for (auto [i, v] : llvm::enumerate(op->getOperands())) {
84 if (auto blockArg = dyn_cast<BlockArgument>(v);
85 blockArg && blockArg.getOwner() == body)
86 continue; // not an outside value...
87 index = i;
88 return true;
89 }
90 return false;
91}
92
93// Attempt to specialize unary or binary linalg.generic ops to named elementwise
94// ops or linalg.elementwise.
95//
96// Example:
97// %0 = linalg.generic {
98// indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
99// affine_map<(d0, d1) -> (d0, d1)>],
100// iterator_types = ["parallel", "parallel"]
101// } ins(%In : tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
102// ^bb0(%in: f32, %out: f32):
103// %1 = math.exp %in : f32
104// linalg.yield %1 : f32
105// } -> tensor<?x?xf32>
106//
107// is specialized to either
108// linalg.exp ins(...) outs(...) -> ...
109// or
110// linalg.elementwise kind=#linalg.elementwise_kind<exp> ...
111//
112// Only the category op can carry non-identity indexing maps; these are
113// transferred verbatim from the `genericOp`.
114//
115// In addition to the canonical forms used by the generalization path, this
116// function can handle the following variations:
117//
118// 1) Swapped operands in binary ops (see the `areBinOpsSwapped` helper)
119// 2) Unary generic ops with a binary body op (see the
120// `findIndexOfScalarOperand` helper)
121static FailureOr<LinalgOp> specializeLinalgElementwise(RewriterBase &rewriter,
122 GenericOp genericOp,
123 bool emitCategoryOp) {
// Returns the replacement op on success. Every failure path returns before
// `replaceOp` runs, so on failure the IR is unchanged and the reason is
// reported through `rewriter.notifyMatchFailure`.
124 bool hasNonIdentityMaps =
125 !llvm::all_of(genericOp.getIndexingMapsArray(),
126 [](AffineMap map) { return map.isIdentity(); });
127
128 // Early exit: Named ops cannot carry user-defined maps.
129 if (hasNonIdentityMaps && !emitCategoryOp)
130 return rewriter.notifyMatchFailure(
131 genericOp,
132 "non-identity indexing maps prevent specialization to named op");
133
134 // Classify the generic op.
135 bool isUnary = genericOp.getNumDpsInputs() == 1;
136 bool isBinary = genericOp.getNumDpsInputs() == 2;
137
138 // Will inspect the body operation to determine named op or elementwise kind.
139 Operation *op = &genericOp.getBody()->front();
140
141 // Detect variations from canonical forms.
142 bool hasSwappedOperands = isBinary && areBinOpsSwapped(genericOp);
// `scalarOprIdx` is only meaningful when `hasScalarOperand` is true.
143 int scalarOprIdx = -1;
144 bool hasScalarOperand = isUnary && op->getNumOperands() == 2 &&
145 findIndexOfScalarOperand(genericOp, scalarOprIdx);
146
147 // Helper to dispatch between named op and `linalg.elementwise`.
148 // Lambdas with explicit template parameter list are a C++20 feature, hence
149 // the dummy op object.
// `namedOp` is only used for its type. Passing `nullptr` means "no named-op
// equivalent" and is only valid when `emitCategoryOp` is set; otherwise the
// `llvm_unreachable` below fires. `mayHoistScalarOperand = false` keeps a
// captured scalar from being turned into an extra broadcast input.
150 auto replaceOp = [&](auto namedOp, ElementwiseKind kind,
151 bool mayHoistScalarOperand = true) -> LinalgOp {
152 SmallVector<Value> inputs = genericOp.getDpsInputs();
153 if (hasSwappedOperands)
154 std::swap(inputs[0], inputs[1]);
155
156 LinalgOp newOp;
157 if (!emitCategoryOp) {
158 using NamedOpTy = decltype(namedOp);
159 if constexpr (!std::is_null_pointer_v<NamedOpTy>)
160 newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
161 genericOp.getDpsInits(),
163 else
164 llvm_unreachable("Missing named op type");
165 } else {
166 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
167 // Swap indexing maps, too.
168 if (hasSwappedOperands)
169 std::swap(indexingMaps[0], indexingMaps[1]);
170
171 // Represent unary generic op as a binary `linalg.elementwise` with a
172 // scalar operand and broadcasting map.
173 if (hasScalarOperand && mayHoistScalarOperand) {
174 // Adjust inputs and indexing maps accordingly.
175 inputs.insert(inputs.begin() + scalarOprIdx,
176 op->getOperand(scalarOprIdx));
// Zero-result map over the loop dims: the scalar is read unchanged at every
// iteration. NOTE(review): the dim count comes from getNumParallelLoops(),
// which equals the total loop count only if every loop is parallel;
// presumably guaranteed by the elementwise precondition - confirm.
177 auto scalarBroadcastMap =
178 AffineMap::get(genericOp.getNumParallelLoops(), /*symbolCount=*/0,
179 rewriter.getContext());
180 indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
181 scalarBroadcastMap);
182 }
183 newOp = ElementwiseOp::create(
184 rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
185 ElementwiseKindAttr::get(rewriter.getContext(), kind),
186 rewriter.getAffineMapArrayAttr(indexingMaps));
187 }
188
189 rewriter.replaceOp(genericOp, newOp);
190 return newOp;
191 };
192
193 if (isUnary) {
194 if (isa<math::ExpOp>(op))
195 return replaceOp(ExpOp{}, ElementwiseKind::exp);
196 if (isa<math::LogOp>(op))
197 return replaceOp(LogOp{}, ElementwiseKind::log);
198 if (isa<math::AbsFOp>(op))
199 return replaceOp(AbsOp{}, ElementwiseKind::abs);
200 if (isa<math::CeilOp>(op))
201 return replaceOp(CeilOp{}, ElementwiseKind::ceil);
202 if (isa<math::FloorOp>(op))
203 return replaceOp(FloorOp{}, ElementwiseKind::floor);
204 if (isa<arith::NegFOp>(op))
205 return replaceOp(NegFOp{}, ElementwiseKind::negf);
// `1.0 / x` is a reciprocal. The constant numerator must not be hoisted as a
// scalar input, hence `mayHoistScalarOperand = false`.
// NOTE(review): `cast<FloatAttr>` asserts if the constant's value is not a
// FloatAttr (e.g. a dense splat when the element type is a vector);
// consider `dyn_cast`.
206 if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
207 if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
208 divOp.getLhs().getDefiningOp()))
209 if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
210 return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
211 /*mayHoistScalarOperand=*/false);
212 }
213 if (isa<math::RoundOp>(op))
214 return replaceOp(RoundOp{}, ElementwiseKind::round);
215 if (isa<math::SqrtOp>(op))
216 return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
217 if (isa<math::RsqrtOp>(op))
218 return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
// `x * x` (same value on both sides) is a square. A unary-body mulf with
// distinct operands falls through to the scalar-operand path below.
219 if (auto mulOp = dyn_cast<arith::MulFOp>(op);
220 mulOp && mulOp.getLhs() == mulOp.getRhs())
221 return replaceOp(SquareOp{}, ElementwiseKind::square);
222 if (isa<math::TanhOp>(op))
223 return replaceOp(TanhOp{}, ElementwiseKind::tanh);
224 if (isa<math::ErfOp>(op))
225 return replaceOp(ErfOp{}, ElementwiseKind::erf);
226
227 // The following ops only have the category (elementwise) form, but no
228 // linalg.* named op equivalent.
229 if (emitCategoryOp) {
230 if (isa<math::SinOp>(op))
231 return replaceOp(nullptr, ElementwiseKind::sin);
232 if (isa<math::CosOp>(op))
233 return replaceOp(nullptr, ElementwiseKind::cos);
234 if (isa<math::TanOp>(op))
235 return replaceOp(nullptr, ElementwiseKind::tan);
236 if (isa<math::AcosOp>(op))
237 return replaceOp(nullptr, ElementwiseKind::acos);
238 if (isa<math::AcoshOp>(op))
239 return replaceOp(nullptr, ElementwiseKind::acosh);
240 if (isa<math::AsinOp>(op))
241 return replaceOp(nullptr, ElementwiseKind::asin);
242 if (isa<math::AsinhOp>(op))
243 return replaceOp(nullptr, ElementwiseKind::asinh);
244 if (isa<math::AtanOp>(op))
245 return replaceOp(nullptr, ElementwiseKind::atan);
246 if (isa<math::AtanhOp>(op))
247 return replaceOp(nullptr, ElementwiseKind::atanh);
248 if (isa<math::Log10Op>(op))
249 return replaceOp(nullptr, ElementwiseKind::log10);
250 if (isa<math::Log1pOp>(op))
251 return replaceOp(nullptr, ElementwiseKind::log1p);
252 if (isa<math::Log2Op>(op))
253 return replaceOp(nullptr, ElementwiseKind::log2);
254 }
255
256 // At this point, we exhaustively checked the available unary named ops. The
257 // 1-input generic op might be representable as a `linalg.elementwise` that
258 // broadcasts a scalar operand. But if we can't emit the category op or
259 // don't have a scalar operand, exit now.
260 if (!emitCategoryOp || !hasScalarOperand)
261 return rewriter.notifyMatchFailure(
262 genericOp, "unary elementwise operation cannot be specialized to "
263 "named or category op");
264 }
265
// Reached for every non-unary generic, and for unary generics whose binary
// body op has a scalar operand when `emitCategoryOp` is set. In the latter
// case `replaceOp` hoists the scalar into an extra broadcast input.
266 // Boolean-typed `linalg.add` and `linalg.mul` require special handling.
267 bool allBool = llvm::all_of(op->getOperands(),
268 [](Value v) { return v.getType().isInteger(1); });
269
270 if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
271 (allBool && isa<arith::OrIOp>(op)))
272 return replaceOp(AddOp{}, ElementwiseKind::add);
273 if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
274 return replaceOp(SubOp{}, ElementwiseKind::sub);
275 if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
276 (allBool && isa<arith::AndIOp>(op)))
277 return replaceOp(MulOp{}, ElementwiseKind::mul);
278 if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
279 return replaceOp(DivOp{}, ElementwiseKind::div);
280 if (isa<arith::DivUIOp>(op))
281 return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
282 if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
283 return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
284 if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
285 return replaceOp(MinOp{}, ElementwiseKind::min_signed);
286 if (emitCategoryOp) {
287 // No named ops for unsigned maximum/minimum.
288 if (isa<arith::MaxUIOp>(op))
289 return replaceOp(nullptr, ElementwiseKind::max_unsigned);
290 if (isa<arith::MinUIOp>(op))
291 return replaceOp(nullptr, ElementwiseKind::min_unsigned);
292 }
293 if (isa<math::PowFOp>(op))
294 return replaceOp(PowFOp{}, ElementwiseKind::powf);
295
296 return rewriter.notifyMatchFailure(
297 genericOp,
298 "elementwise operation cannot be specialized to named or category op");
299}
300
301//===----------------------------------------------------------------------===//
302// Specialize linalg generic to matmul variants.
303//===----------------------------------------------------------------------===//
304// Identifies a linalg.generic that is essentially a named op of the form
305// `linalg.{batch_}?matmul` (transposed operands are expressed through the named op's `indexing_maps`) or `linalg.mmt4d`.
306//
307// It is possible that a linalg.generic may be implementing a matmul but not
308// in a straight-forward way e.g. below is matrix multiply over some slice
309// ```
310// %0 = linalg.generic {
311// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
312// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
313// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
314// iterator_types = ["parallel", "parallel", "parallel"]}
315// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
316// outs(%C : tensor<20x20x20xf32>) {
317// ^bb0(%a: f32, %b: f32, %c : f32):
318// %mul = arith.mulf %a, %b : f32
319// %add = arith.addf %mul, %c : f32
320// linalg.yield %add : f32
321// } -> tensor<20x20x20xf32>
322// ```
323// It is not possible to represent above as named op.
324// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
325// not the same as linalg.generic above.
326namespace {
// Classifies how two consecutive results of an operand's indexing map address
// an expected (row, column) pair of loop dims (see `matchOperandMap`).
enum class IndexMatchResult {
  Match = 0,      // dims appear in the expected (row, column) order.
  Transposed = 1, // dims appear in swapped (column, row) order.
  Mismatch = 2    // anything else.
};
332
333// Checks whether the input Affine `map` contains two consecutive dims that
334// can be interpreted as accessing a 2D matrix. It is assumed that the row
335// column dimension are adjacent axis (in this order) and start at
336// `rowDimIdx` in the input map.
337//
338// e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
339// whether the map of A is identity (match), transposed, or something
340// completely different (mis-match). Similar for B and C.
341static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
342 unsigned expectedPosOfRowDim,
343 unsigned expectedPosOfColDim) {
344 // Get the matrix multiply indices. They are past the batch indices.
345 auto exprOfRowDim = map.getResults()[rowDimIdx];
346 auto exprOfColDim = map.getResults()[rowDimIdx + 1];
347
348 // They should be pure dimension ids.
349 if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
350 exprOfColDim.getKind() != AffineExprKind::DimId)
351 return IndexMatchResult::Mismatch;
352
353 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
354 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
355
356 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
357 return IndexMatchResult::Match;
358
359 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
360 return IndexMatchResult::Transposed;
361
362 return IndexMatchResult::Mismatch;
363}
364
365// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
366// All the variants expressed as pseudo regular expression:
367// `linalg.{batch_}?matmul` have same number of ins/out, so it's easy to
368// stamp different versions.
369// `castTy` is an optional type function that indicates whether (and which) cast
370// attribute is needed for the named matmul op variant.
// Also used for `linalg.mmt4d` and `linalg.contract`. The new op receives the
// first two DPS inputs and the first DPS init of `op` (callers in this file
// check for exactly 2 inputs and 1 init), plus an `indexing_maps` attribute
// built from `indexingMaps`.
371template <typename NamedOpTy>
372static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
373 std::optional<TypeFn> castTy,
374 ArrayRef<AffineMap> indexingMaps) {
376 // Only explicitly specify the cast attribute for unsigned cast; signed is
377 // the default for linalg.matmul/linalg.batch_matmul.
378 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
379 auto castAttr = rewriter.getNamedAttr(
380 "cast", TypeFnAttr::get(rewriter.getContext(), *castTy));
381 attributes.push_back(castAttr);
382 }
383
384 // Set the original generic's maps to preserve operand indexing semantics like
385 // transposition.
// NOTE: for linalg.matmul / linalg.batch_matmul the caller passes maps rebuilt
// in the named op's canonical loop order, not the generic's own maps.
386 SmallVector<Attribute, 3> indexingMapsAttrVal =
387 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
388 return AffineMapAttr::get(map);
389 });
390 auto indexingMapsAttr = rewriter.getNamedAttr(
391 "indexing_maps", rewriter.getArrayAttr(indexingMapsAttrVal));
392 attributes.push_back(indexingMapsAttr);
393
394 LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
395 op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
396 ValueRange{op.getDpsInits()[0]}, attributes);
397
398 return namedOp;
399}
400
401// Returns the cast type to use for a matmul-like named op. If the generic
402// contains casts that cannot be represented (e.g. output casts or mixed
403// signedness), return std::nullopt.
404static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
405 bool foundCastForMatmulOutput = false;
406 SmallVector<TypeFn> castTyFns;
407 genericOp.getBody()->walk([&](CastOpInterface castOp) {
408 // Collect forward slice of the cast op to check if it is for the matmul
409 // output.
410 SetVector<Operation *> forwardSlice;
411 getForwardSlice(castOp, &forwardSlice);
412
413 // If there is no multiplication op in the forward slice, then this cast
414 // op is for the matmul output. Cast ops on matmul output cannot be
415 // expressed by the matmul op variant.
416 if (!llvm::any_of(forwardSlice, [](Operation *op) {
417 // We check explicitly for these multiplication ops in
418 // `specializeLinalgContractions()` to infer matmul-like ops.
419 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
420 })) {
421 foundCastForMatmulOutput = true;
422 return WalkResult::interrupt();
423 }
424
425 // Determine the cast type.
426 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
427 castTyFns.push_back(TypeFn::cast_unsigned);
428 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
429 castTyFns.push_back(TypeFn::cast_signed);
430
431 return WalkResult::advance();
432 });
433
434 if (foundCastForMatmulOutput)
435 return std::nullopt;
436
437 if (!castTyFns.empty()) {
438 // If there were multiple different cast types found, then we can't express
439 // them using matmul-like ops. They only allow a single cast type for all
440 // inputs.
441 if (!llvm::all_equal(castTyFns))
442 return std::nullopt;
443 return castTyFns.front();
444 }
445
446 // Default to signed cast for matmul-like ops.
447 return TypeFn::cast_signed;
448}
449
450static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
451 GenericOp genericOp,
452 std::optional<TypeFn> castTy,
453 ContractionDimensions &dims) {
454 // Should all be rank 4 and dim 6
455 auto indexingMaps = genericOp.getIndexingMapsArray();
456 if (llvm::any_of(indexingMaps, [](AffineMap m) {
457 return m.getResults().size() != 4 || m.getNumDims() != 6;
458 }))
459 return failure();
460
461 auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
462 auto aInner = matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
463
464 auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
465 auto bInner = matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
466
467 auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
468 auto cInner = matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
469
470 if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
471 return failure();
472 if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
473 return failure();
474
475 SmallVector<AffineMap> namedOpMaps = {indexingMaps[0], indexingMaps[1],
476 indexingMaps[2]};
477
478 return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
479 namedOpMaps);
480}
481
482static bool isSupportedContractionPair(Operation *first, Operation *second) {
483 if (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second))
484 return true;
485 if (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second))
486 return true;
487 if (isa<complex::MulOp>(first) && isa<complex::AddOp>(second))
488 return true;
489 if (isa<arith::AndIOp>(first) && isa<arith::OrIOp>(second) &&
490 first->getResult(0).getType().isInteger(1))
491 return true;
492
493 return false;
494}
495
496// Converts linalg.generic to named linalg.*matmul* where possible.
497static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
498 GenericOp genericOp,
499 bool emitCategoryOp) {
500 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
501 return failure();
502
503 // Early exit if not projected permutations.
504 auto mapRange = genericOp.getIndexingMapsArray();
505 if (llvm::any_of(mapRange,
506 [](AffineMap m) { return !m.isProjectedPermutation(); }))
507 return failure();
508
509 // Only contractions that can be represented by named linalg ops are
510 // eligible for specialization:
511 // - mul + add (floating-point, integer, complex)
512 // - and + or (bool)
513 if (!mlir::linalg::detail::isContractionBody(*genericOp.getBlock(),
514 isSupportedContractionPair))
515 return failure();
516
517 // Determine the cast type for the named matmul op, or bail out if casts
518 // cannot be represented by the named op.
519 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
520 if (!castTy)
521 return rewriter.notifyMatchFailure(
522 genericOp, "contains invalid cast ops for the named matmul op");
523
524 // In case of category op, wider range of variants is supported.
525 if (emitCategoryOp)
526 return replaceWithMatmulVariant<ContractOp>(
527 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
528
529 // Further checks for named variants.
530 //
531 // Linalg generic contraction can be across multiple axis e.g.
532 // ```
533 // linalg.generic
534 // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
535 // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
536 // affine_map<(m, n, k1, k2) -> (m, n)>],
537 // iterator_types = ["parallel", "parallel",
538 // "reduction", "reduction"]}
539 // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
540 // outs(%C : tensor<10x40xf32>) {
541 // ^bb0(%a: f32, %b: f32, %c: f32):
542 // %1 = arith.mulf %a, %b : f32
543 // %2 = arith.addf %c, %1 : f32
544 // linalg.yield %2 : f32
545 // } -> tensor<10x40xf32>
546 // ```
547 // In above contraction, there are two reduction dimensions {k1, k2}
548 // and although a valid linalg contraction, it is not a named-op
549 // matrix multiply kind. Therefore, reject multi-dim reduction.
550 auto res = inferContractionDims(genericOp);
551 if (!succeeded(res))
552 return failure();
553 auto dims = *res;
554 if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
555 return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
556 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
557 return failure();
558
559 // Check rank of operands
560 auto indexingMaps = genericOp.getIndexingMapsArray();
561 if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
562 return m.getResults().size() !=
563 dims.batch.size() + 2 /* any two of {m,n,k} */;
564 }))
565 return failure();
566
567 auto numOfBatchDims = dims.batch.size();
568 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
569 return failure();
570
571 if (numOfBatchDims) {
572 // Each operand in a linalg generic contraction could express different
573 // permutations for its batch dimension. But for named op it must be
574 // identity since separate maps are not specified.
575 if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
576 for (unsigned i = 0; i < numOfBatchDims; ++i) {
577 auto expr = m.getResults()[i];
578 if (expr.getKind() != AffineExprKind::DimId ||
579 cast<AffineDimExpr>(expr).getPosition() != i)
580 return true;
581 }
582 return false;
583 }))
584 return failure();
585 }
586
587 auto a =
588 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
589 auto b =
590 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
591 auto c =
592 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
593
594 if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
595 return failure();
596
597 // Build indexing maps for the named op in its canonical dimension ordering
598 auto *ctx = genericOp.getContext();
599 unsigned numLoopDims = numOfBatchDims + 3;
600 unsigned mIdx = numOfBatchDims;
601 unsigned nIdx = mIdx + 1;
602 unsigned kIdx = mIdx + 2;
603
604 // TODO: add support for indexing_maps with broadcasts.
605 auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
606 SmallVector<unsigned> tensorDims;
607 for (unsigned i = 0; i < numOfBatchDims; ++i)
608 tensorDims.push_back(i);
609 if (match == IndexMatchResult::Transposed)
610 llvm::append_values(tensorDims, colIdx, rowIdx);
611 else
612 llvm::append_values(tensorDims, rowIdx, colIdx);
613 return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
614 };
615
616 auto mapA = makeMap(a, mIdx, kIdx);
617 auto mapB = makeMap(b, kIdx, nIdx);
618 auto mapC = makeMap(c, mIdx, nIdx);
619
620 SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
621
622 // Codegen the different matmul variants.
623 if (numOfBatchDims) {
624 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
625 namedOpMaps);
626 }
627 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
628 namedOpMaps);
629}
630
631/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
632/// with `dilations` and `strides`.
633template <typename ConvOpTy>
634static FailureOr<LinalgOp>
635specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
636 ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
637 SmallVector<Value> inputs = genericOp.getDpsInputs();
638 ValueRange outputs = genericOp.getDpsInits();
639 SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
640 ? TypeRange(ValueRange(outputs))
641 : TypeRange{};
642 LinalgOp namedOp;
643 // Ops with no dilations and no strides.
644 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
645 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
646 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
647 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
648 inputs, outputs);
649 } else {
650 Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
651 Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
652 namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
653 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
654 }
655 return namedOp;
656}
657
658/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
659static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
660 GenericOp genericOp) {
661#define CONV_OP_SPECIALIZER(ConvOpTy) \
662 if (std::optional<DilationsAndStrides> convParams = \
663 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
664 return specializeToConvOp<ConvOpTy>( \
665 rewriter, genericOp, convParams->dilations, convParams->strides); \
666 // -----------------------------
667 // Convolution ops.
668 // -----------------------------
669 CONV_OP_SPECIALIZER(linalg::Conv1DOp);
670 CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
671 CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
672 CONV_OP_SPECIALIZER(linalg::Conv2DOp);
673 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
674 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
675 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
676 CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
677 CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
678 CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
679 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
680 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
681 CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
682 CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
683 CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
684 CONV_OP_SPECIALIZER(linalg::Conv3DOp);
685 CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfOp);
686 CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfQOp);
687 CONV_OP_SPECIALIZER(linalg::Conv3DNcdhwFcdhwOp);
688 // -----------------------------
689 // Depthwise Convolution ops.
690 // -----------------------------
691 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
692 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
693 CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
694 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
695 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
696 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
697 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
698 CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
699 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcOp);
700 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNcdhwCdhwOp);
701 CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
702 // -----------------------------
703 // Pooling ops.
704 // -----------------------------
705 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
706 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
707 CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
708 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
709 CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
710 CONV_OP_SPECIALIZER(linalg::PoolingNchwSumOp);
711 CONV_OP_SPECIALIZER(linalg::PoolingNchwMaxOp);
712 CONV_OP_SPECIALIZER(linalg::PoolingNwcSumOp);
713 CONV_OP_SPECIALIZER(linalg::PoolingNcwSumOp);
714 CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxOp);
715 CONV_OP_SPECIALIZER(linalg::PoolingNwcMaxUnsignedOp);
716 CONV_OP_SPECIALIZER(linalg::PoolingNcwMaxOp);
717 CONV_OP_SPECIALIZER(linalg::PoolingNwcMinOp);
718 CONV_OP_SPECIALIZER(linalg::PoolingNwcMinUnsignedOp);
719 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcSumOp);
720 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMaxOp);
721 CONV_OP_SPECIALIZER(linalg::PoolingNdhwcMinOp);
722#undef CONV_OP_SPECIALIZER
723 return failure();
724}
725
726} // namespace
727
728//===----------------------------------------------------------------------===//
729// Categorize linalg generic to named op where possible.
730//===----------------------------------------------------------------------===//
732 RewriterBase &rewriter, GenericOp genericOp,
// Matchers run in a fixed order: elementwise, contraction, then (named ops
// only) copy, fill, broadcast, transpose and convolution. As soon as a
// classification check (e.g. the elementwise or contraction interface test)
// holds, that specializer's result is returned directly, even if it is a
// failure; later matchers are not tried.
734 // Elementwise - e.g. exp, add
735 if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps) ||
736 isaElemwiseSingleBinaryOpInterface(genericOp, options.emitCategoryOps)) {
737 return specializeLinalgElementwise(rewriter, genericOp,
738 options.emitCategoryOps);
739 }
740
741 // Contraction - e.g. matmul
742 if (isaContractionOpInterface(genericOp)) {
743 return specializeLinalgContractions(rewriter, genericOp,
744 options.emitCategoryOps);
745 }
746
747 // Early exit in case of category specialization.
748 // TODO: Remove when matches for other ops account for both named and
749 // category.
750 if (options.emitCategoryOps)
751 return rewriter.notifyMatchFailure(
752 genericOp, "no matching category op specialization");
753
754 // Copy
755 if (isaCopyOpInterface(genericOp)) {
756 LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
757 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
758 return namedOp;
759 }
760
761 // Fill
762 if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
763 // Fill the first init with the value detected by `isaFillOpInterface`.
764 LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
765 genericOp, *fillValue, genericOp.getDpsInits()[0]);
766 return namedOp;
767 }
768
769 // Broadcast
// `isaBroadcastOpInterface` yields the dimension list passed to
// linalg.broadcast.
770 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
771 isaBroadcastOpInterface(genericOp);
772 if (equivalentToBroadcast) {
773 auto dims = *equivalentToBroadcast;
774 LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
775 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
776 dims);
777 return namedOp;
778 }
779
780 // Transpose
// `isaTransposeOpInterface` yields the permutation passed to
// linalg.transpose.
781 std::optional<SmallVector<int64_t>> equivalentToTranspose =
782 isaTransposeOpInterface(genericOp);
783 if (equivalentToTranspose) {
784 auto permutation = *equivalentToTranspose;
785 LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
786 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
787 permutation);
788 return namedOp;
789 }
790
791 // Convolution - e.g. *conv/pooling*
792 if (isaConvolutionOpInterface(genericOp))
793 return specializeLinalgConvolutions(rewriter, genericOp);
794
795 return rewriter.notifyMatchFailure(genericOp,
796 "no matching named op specialization");
797}
798
799namespace {
// Pass that rewrites the operation it runs on with the greedy pattern driver
// (see `runOnOperation`). Its base class presumably comes from the
// tablegen-generated code pulled in via GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS.
800struct LinalgSpecializeGenericOpsPass
802 LinalgSpecializeGenericOpsPass> {
803
805 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
806 void runOnOperation() override;
807};
808} // namespace
809
810void LinalgSpecializeGenericOpsPass::runOnOperation() {
811 RewritePatternSet patterns(&getContext());
814
// Apply the collected patterns greedily to the pass's root operation and
// signal pass failure if the greedy driver reports failure.
815 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
816 signalPassFailure();
817}
818
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
static bool findIndexOfScalarOperand(GenericOp genericOp, int &index)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
static FailureOr< LinalgOp > specializeLinalgElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:190
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
unsigned getNumOperands()
Definition Operation.h:371
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
OpOperand & getOpOperand(unsigned idx)
Definition Operation.h:413
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch