MLIR 23.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/MLIRContext.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
// Collects the indexing maps of all operands of `linalgOp` that are not in
// `droppedOperands`; if no maps remain, dropping is legal only when the op
// has zero loops.
// NOTE(review): the leading signature line of this definition is missing from
// this view (scrape truncation); the surviving lines are preserved verbatim.
41 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42 SmallVector<AffineMap> indexingMaps;
43 for (auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
45 continue;
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47 }
48 if (indexingMaps.empty()) {
49 // If there are no indexing maps, the operand can only be dropped
50 // if the op has no loops.
51 return linalgOp.getNumLoops() == 0;
52 }
// NOTE(review): line 53 is missing here — presumably the start of the return
// expression that checks invertibility of the concatenated remaining maps
// (the result is compared against a default AffineMap); confirm upstream.
54 indexingMaps, linalgOp.getContext())) != AffineMap();
55}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
// Recognizes a broadcast: a named linalg.broadcast yields its dimensions
// directly; otherwise a linalg.generic qualifies when it is all-parallel
// single-input/single-output over memref/ranked-tensor types, the output map
// is the identity, and the input map is a strictly increasing list of dim
// exprs shorter than the output's. Returns the dims absent from the input
// map (the broadcasted dims).
// NOTE(review): the line naming this function is missing from this view
// (scrape truncation); the body is preserved verbatim.
140std::optional<SmallVector<int64_t>>
142 if (auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
143 return SmallVector<int64_t>(broadcastOp.getDimensions().begin(),
144 broadcastOp.getDimensions().end());
145
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
147 if (!op)
148 return std::nullopt;
149
150 // Structural.
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
153 return std::nullopt;
154
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
159 return std::nullopt;
160
161 // Check output is identity map. Broadcast could additionally be
162 // employing permutation of indices and that would be expressible
163 // in linalg.generic but is not expressible for named broadcast op.
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
166 return std::nullopt;
167
168 SmallVector<int64_t> position;
169 auto srcMap = op.getIndexingMapsArray()[0];
170
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
172 return std::nullopt;
173
174 // Check input map is monotonically increasing DimIds.
175 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
177 if (!expr)
178 return std::nullopt;
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
181 return std::nullopt;
182 position.push_back(expr.getPosition());
183 }
184
185 SmallVector<int64_t> broadcastedDims;
186 auto numDims = srcMap.getNumDims();
187 // This is quadratic but number of items is generally small.
188 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
191 }
192 return broadcastedDims;
193}
194
195//===----------------------------------------------------------------------===//
196// TransposeOpInterface implementation
197//===----------------------------------------------------------------------===//
// Recognizes a transpose and returns its permutation vector such that
// dim(result, i) == dim(input, permutation[i]).
// NOTE(review): the line naming this function is missing from this view
// (scrape truncation); the body is preserved verbatim.
198std::optional<SmallVector<int64_t>>
200 // To specialize as a transpose op, the genericOp must be
201 // all parallel loops, single input, single output, and its body
202 // should be just a yield op, yielding input as output as is (no compute).
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
205 return std::nullopt;
206
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
209 return std::nullopt;
210
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
213
214 // linalg.transpose permutes the dimensions of input using this
215 // rule: dim(result, i) = dim(input, permutation[i])
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
217 return std::nullopt;
218
219 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
// Invert the input map: the i-th result of the input map says which loop dim
// feeds position i, so record i at that loop dim's slot.
220 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
223 }
224 return permutation;
225}
226
227//===----------------------------------------------------------------------===//
228// Elementwise Single Unary/Binary-OpInterface implementation
229//===----------------------------------------------------------------------===//
230static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
231 unsigned arity) {
232 // Check all loops are parallel.
233 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
234 return false;
235
236 // Check there are arity-inputs, 1-output and all are identity-maps.
237 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
238 !llvm::all_of(op.getIndexingMapsArray(),
239 [](AffineMap map) { return map.isIdentity(); }))
240 return false;
241
242 // Init should not be referenced for elementwise operations.
243 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
244 return false;
245
246 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
247 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
248 // the body, where the first is the elementwise single op and the second a
249 // yield.
250 Block *body = op.getBody();
251 if (body->getOperations().size() != 2)
252 return false;
253
254 // The payload op must have one result and at least arity-many operands
255 // (otherwise not all inputs can be used). It can have additional operands
256 // from outside of the generic op (e.g. div(1, x) for linalg.reciprocal) or
257 // use an input more than once (e.g. mul(x, x) for linalg.square).
258 Operation *oper = &body->front();
259 if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
260 return false;
261
262 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
263 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
264 yieldOp->getOperand(0).getDefiningOp() != oper);
265}
266
// Returns true if `op` is a single-op elementwise unary generic whose input
// is actually read by the payload.
267bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
268 // All basic elemwise checks.
// NOTE(review): the guard condition is missing from this view (scrape
// truncation) — presumably a failed isaElemwiseSingleUnaryOrBinaryOpInterface
// check with arity 1; confirm upstream.
270 return false;
271
272 // Check input is actually used.
273 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
274 return false;
275 return true;
276}
277
// Returns true if `op` is a single-op elementwise binary generic in which
// both inputs are actually read by the payload.
278bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
// NOTE(review): the guard condition is missing from this view (scrape
// truncation) — presumably a failed isaElemwiseSingleUnaryOrBinaryOpInterface
// check with arity 2; confirm upstream.
280 return false;
281
282 // Check both inputs are used (elementwise).
283 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
284 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
285 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
286 !op.payloadUsesValueFromOperand(inputOpOperand1));
287}
288
289//===----------------------------------------------------------------------===//
290// ContractionOpInterface implementation
291//===----------------------------------------------------------------------===//
292
293/// If the value is defined by a chain of unary side effect-free, go up the
294/// use-def chain until the first value that isn't defined by such an op.
295// TODO: relax to multi-operands with constants, which are technically unary ops
296// as needed (e.g. add5).
// NOTE(review): the signature line of this helper is missing from this view
// (scrape truncation); per the comment above, it takes and returns a Value.
// Walks up the use-def chain while the defining op is a single-operand op
// with no memory effects, returning the first value not produced that way.
298 Operation *op = value.getDefiningOp();
299 while (op && op->getNumOperands() == 1) {
300 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
301 if (!iface || !iface.hasNoEffect())
302 break;
303 value = op->getOperand(0);
304 op = value.getDefiningOp();
305 }
306 return value;
307}
308
// Structurally verifies that `block` computes a contraction accumulation of
// the form yield(reduce(bbarg2, elementwise(bbarg0, bbarg1))), modulo
// side-effect-free unary ops skipped by getSourceSkipUnary; `isaPair`
// validates the (elementwise, reduction) op-kind combination. On mismatch a
// diagnostic is streamed to `errs` and false is returned.
// NOTE(review): the leading signature line is missing from this view (scrape
// truncation); the body is preserved verbatim.
310 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
311 llvm::raw_ostream &errs) {
312 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
313 errs << "no terminator in the block";
314 return false;
315 }
316
317 if (block.getNumArguments() != 3) {
318 errs << "expected block with 3 arguments";
319 return false;
320 }
321
322 Operation *terminator = block.getTerminator();
323 if (terminator->getNumOperands() != 1) {
324 errs << "expected terminator with 1 operand";
325 return false;
326 }
327
328 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
329 Operation *reductionOp = yielded.getDefiningOp();
330 if (!reductionOp || reductionOp->getNumResults() != 1 ||
331 reductionOp->getNumOperands() != 2) {
332 errs << "expected reduction op to be binary";
333 return false;
334 }
335
336 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
337 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
338
339 if (reductionLHS != block.getArgument(2) &&
340 reductionRHS != block.getArgument(2)) {
341 errs << "expected reduction to take block argument #2 as one of the "
342 "operands (modulo unary casts)";
343 return false;
344 }
345
// The non-accumulator side of the reduction must come from a binary
// elementwise op (e.g. the multiply of a matmul).
346 Value contributed = getSourceSkipUnary(
347 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
348 Operation *elementwiseOp = contributed.getDefiningOp();
349 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
350 elementwiseOp->getNumOperands() != 2) {
351 errs << "expected elementwise op to be binary";
352 return false;
353 }
354
355 if (!isaPair(elementwiseOp, reductionOp)) {
356 errs << "expected reduction/elementwise op kind not satisfied";
357 return false;
358 }
359
// The elementwise op must combine bbarg0 and bbarg1 in either order.
360 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
361 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
362 if ((elementwiseLHS == block.getArgument(0) &&
363 elementwiseRHS == block.getArgument(1)) ||
364 (elementwiseLHS == block.getArgument(1) &&
365 elementwiseRHS == block.getArgument(0))) {
366 return true;
367 }
368
369 errs << "expected elementwise op to apply to block arguments (modulo unary "
370 "casts)";
371 return false;
372}
373
374/// Returns true if the two operations are of the kinds specified by a pair of
375/// consecutive template arguments.
// Returns true when (add, mul) match one of the (AddOpTy, MulOpTy) pairs in
// the template argument list, recursing through the remaining pairs.
// NOTE(review): the function header line and the recursive call after the
// `if constexpr` are missing from this view (scrape truncation); the
// surviving lines are preserved verbatim.
376template <typename AddOpTy, typename MulOpTy, typename... Args>
378 static_assert(sizeof...(Args) % 2 == 0,
379 "expected an even number of template arguments");
380 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
381 return true;
382
383 if constexpr (sizeof...(Args) > 0)
385 else
386 return false;
387}
388
389/// Returns true if the block is a body of a contraction with the kinds of
390/// operations given pairwise by template arguments.
391template <typename... Args>
395
396/// Given an `indexingMap` and its corresponding `iterators`, returns
397/// the positions of the iterators of type `iter` that are indexed by
398/// the `indexingMap` as a permutation. This is useful to infer various
399/// subcomputations on a `LinalgOp`. This is performed by looking up
400/// each result in the `indexingMap` and determining whether:
401/// - It is a single AffineDimExpr.
402/// - It is the only result involving this AffineDimExpr.
// NOTE(review): the line carrying this function's name and first parameters
// is missing from this view (scrape truncation); per the doc comment above,
// it takes an indexing map, its iterator types, and an iterator kind.
403static llvm::SmallDenseSet<int64_t>
406 utils::IteratorType iter) {
407 assert(iterators.size() == indexingMap.getNumDims());
408 llvm::SmallDenseSet<int64_t> res;
// Keep a dim position only if it has the requested iterator type and is the
// sole map result involving that dim (i.e. the map acts as a permutation on
// it).
409 for (AffineExpr e : indexingMap.getResults()) {
410 if (auto d = dyn_cast<AffineDimExpr>(e)) {
411 if (iterators[d.getPosition()] == iter &&
412 llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
413 return e.isFunctionOfDim(d.getPosition());
414 }) == 1)
415 res.insert(d.getPosition());
416 }
417 }
418 return res;
419}
420
421namespace {
// File-local shorthands for the two iterator kinds, used throughout the
// contraction/convolution inference helpers below.
422auto par = utils::IteratorType::parallel;
423auto red = utils::IteratorType::reduction;
424} // namespace
425
426/// Infer the iterator types from the init affine map. This looks at which dims
427/// are present in the map results, and returns an iterator types array with
428/// parallel types for dims that are present, and reduction types for dims that
429/// are not present.
// NOTE(review): the line naming this function (taking the output AffineMap)
// is missing from this view (scrape truncation); body preserved verbatim.
430static FailureOr<SmallVector<utils::IteratorType>>
432 if (!map.isProjectedPermutation())
433 return failure();
// Start with everything as reduction, then mark dims appearing in the map's
// results as parallel.
434 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
435 for (auto expr : map.getResults())
436 if (auto dim = dyn_cast<AffineDimExpr>(expr))
437 iterators[dim.getPosition()] = par;
438 return iterators;
439}
440
441/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
442/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
443/// 1. The m dimension is involved in an outer-product along LHS
444/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
445/// 2. The n dimension is involved in an outer-product along RHS
446/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
447/// 3. The k dimension appears as a permutation on LHS and RHS.
448/// 4. m, n and k appear only once in any given indexing.
449/// 5. Optional batch dimensions that appear in all operands are captured.
450/// This allows e.g. detecting that some contraction is embedded within
451/// `linalgOp` with some orthogonal heuristic.
// NOTE(review): the signature lines are missing from this view (scrape
// truncation); from the body, it takes `indexingMaps` (LHS, RHS, RES order)
// and per-dim `iterators`, and classifies dims via set algebra as described
// in the doc comment above.
452static FailureOr<ContractionDimensions>
455 llvm::SmallDenseSet<int64_t> a =
456 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
457 llvm::SmallDenseSet<int64_t> b =
458 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
459 llvm::SmallDenseSet<int64_t> c =
460 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
461
462 // A & C - B are the iterators involved in an outer-product along A (the LHS).
463 llvm::SmallDenseSet<int64_t> ac = a;
464 llvm::set_intersect(ac, c);
465 llvm::set_subtract(ac, b);
466 // B & C - A are the iterators involved in an outer-product along B (the RHS).
467 llvm::SmallDenseSet<int64_t> bc = b;
468 llvm::set_intersect(bc, c);
469 llvm::set_subtract(bc, a);
470 // A & B & C are the "batch" dimensions.
471 llvm::SmallDenseSet<int64_t> batches = a;
472 llvm::set_intersect(batches, b);
473 llvm::set_intersect(batches, c);
474
475 // A & B red are the reduction dimensions.
476 llvm::SmallDenseSet<int64_t> ra =
477 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
478 llvm::SmallDenseSet<int64_t> rb =
479 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
480 llvm::set_intersect(ra, rb);
481
482 // Return each set in sorted order.
483 ContractionDimensions dimensions{
484 SmallVector<unsigned, 2>(batches.begin(), batches.end()),
485 SmallVector<unsigned, 2>(ac.begin(), ac.end()),
486 SmallVector<unsigned, 2>(bc.begin(), bc.end()),
487 SmallVector<unsigned, 2>(ra.begin(), ra.end())};
488 llvm::sort(dimensions.batch);
489 llvm::sort(dimensions.m);
490 llvm::sort(dimensions.n);
491 llvm::sort(dimensions.k);
492 return dimensions;
493}
494
// Entry point taking a LinalgOp: requires exactly 2 inputs and 1 init, then
// delegates to inferContractionDimsImpl with the op's maps and iterators.
// NOTE(review): the line naming this function is missing from this view
// (scrape truncation); body preserved verbatim.
495FailureOr<ContractionDimensions>
497 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
498 return failure();
499 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
500 linalgOp.getIteratorTypesArray());
501}
502
// Entry point taking raw indexing maps: requires exactly 3 maps, derives the
// iterator types from the output map, then delegates to the impl.
// NOTE(review): the line naming this function is missing from this view
// (scrape truncation); body preserved verbatim.
503FailureOr<ContractionDimensions>
505 if (indexingMaps.size() != 3)
506 return failure();
507 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
508 if (failed(iterators))
509 return failure();
510 return inferContractionDimsImpl(indexingMaps, iterators.value());
511}
512
513namespace mlir::linalg::detail {
522} // namespace mlir::linalg::detail
523
// Matches `op` as a structural contraction: a LinalgOp with 2 inputs and 1
// init, at least one reduction loop, projected-permutation indexing maps,
// and a body combining one of the listed (mul, add)-style op pairs. When a
// match succeeds and `dimensions` is non-null, the inferred contraction
// dims are written through it.
// NOTE(review): the function signature and the per-check
// `return MatchContractionResult::...` lines are missing from this view
// (scrape truncation); the surviving lines are preserved verbatim.
527 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
528 if (!linalgOp)
530 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
532 auto mapRange = linalgOp.getIndexingMapsArray();
533 if (linalgOp.getNumReductionLoops() == 0)
535 if (llvm::any_of(mapRange,
536 [](AffineMap m) { return !m.isProjectedPermutation(); }))
538 // TODO: more fields than add/mul.
539 // clang-format off
541 arith::MulFOp, arith::AddFOp,
542 arith::MulIOp, arith::AddIOp,
543 complex::MulOp, complex::AddOp,
544 arith::AndIOp, arith::OrIOp>(
545 *linalgOp.getBlock())) {
547 }
548 // clang-format on
549
550 if (dimensions) {
551 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
552 assert(succeeded(res) && "unexpected failure to infer contraction dims");
553 *dimensions = *res;
554 }
556}
557
// Maps a MatchContractionResult value to a human-readable diagnostic string;
// the success case maps to the empty string.
// NOTE(review): the function name line and the `case` labels of this switch
// are missing from this view (scrape truncation); the surviving lines are
// preserved verbatim.
558StringRef
560 switch (res) {
562 return "expected a LinalgOp";
564 return "expected op with 2 inputs and 1 output";
566 return "expected at least 1 reduction";
568 return "expected indexing maps to be projected permutations";
570 return "expected add/mul op in the body";
572 return "";
573 }
574 llvm_unreachable("unhandled MatchContractionResult case");
}
576
// Returns true when the op either already implements ContractionOpInterface
// or structurally matches a contraction.
// NOTE(review): the signature line and the second half of the return
// expression are missing from this view (scrape truncation); the surviving
// lines are preserved verbatim.
578 if (!linalgOp)
579 return false;
580 Operation *op = linalgOp.getOperation();
581 return isa<ContractionOpInterface>(op) ||
584}
585
586/// Verify that a LinalgOp `op` is a contraction.
587/// A Linalg contraction is defined in general terms:
588/// 1. Has 2 input and 1 output shapes.
589/// 2. Has at least one reduction dimension.
590/// 3. Has only projected permutation indexing maps.
591/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
592/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
593/// operations that may change the type (e.g. for mixed-precision).
594/// As a consequence, when vectorization of such an op occurs, the only special
595/// behavior is that the (unique) MulOpType is vectorized into a
596/// `vector.contract`. All other ops are handled in a generic fashion.
597/// In the future, we may wish to allow more input arguments and elementwise and
598/// constant operations that do not involve the reduction dimension(s).
605
606//===----------------------------------------------------------------------===//
607// ConvolutionOpInterface implementation
608//===----------------------------------------------------------------------===//
609
610/// Of the given two expressions returns one that is of type T (`lhs` gets
611/// preference over `rhs`)
// NOTE(review): the function header line is missing from this view (scrape
// truncation); per the comment above, it selects whichever of `lhs`/`rhs`
// is of type T, preferring `lhs`, and returns nullptr when neither matches.
612template <typename T>
614 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
615}
616
617namespace {
618/// Walk the indexing expressions for input of a convolution operation to verify
619/// its of the right form, either
620/// - AffineDimExpr
621/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
622/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
623///
624/// classifies the AffineDimExpr as convolved dimensions or unconvolved
625/// dimensions and verifies each dimension occurs only once.
626struct ConvAccessExprWalker
627 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
628 // Stores dimensions used in expressions of the above form.
629 llvm::SmallDenseSet<int64_t> convolvedDims;
630 // Stores the dual mapping between LHS and RHS of convolution exprs.
631 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
632 // Stores single use dimensions used by an AffineDimExpr.
633 llvm::SmallDenseSet<int64_t> unConvolvedDims;
634 // Stores a mapping from convolved dims to their coefficient.
635 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
636
637 // Removes dims with multiple uses in the source input map from dimension
638 // sets tracked by this walker.
639 void clearMultiUseDims(AffineMap map) {
640 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
641 if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
642 return e.isFunctionOfDim(dimPos);
643 }) > 1) {
644 convolvedDims.erase(dimPos);
645 unConvolvedDims.erase(dimPos);
646 // If a duplicate dim is marked as convolved, the pair of the duplicate
647 // dim must be removed from the map as well.
648 auto it = convolvedDimMapping.find(dimPos);
649 if (it != convolvedDimMapping.end()) {
650 int64_t pairedDim = it->second;
651 convolvedDims.erase(pairedDim);
652 unConvolvedDims.erase(pairedDim);
653 strideAndDilationMapping.erase(pairedDim);
654 convolvedDimMapping.erase(dimPos);
655 convolvedDimMapping.erase(pairedDim);
656 }
657 }
658 }
659 }
660
// A lone dim expr is an unconvolved dim; reject duplicates.
661 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
662 unsigned position = dimExpr.getPosition();
663 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
664 return failure();
665 }
666 unConvolvedDims.insert(position);
667 return success();
668 }
669
// Bare symbols/constants are not valid convolution access exprs.
670 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
671
672 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
673
674 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
675 // In pre-order visit, top level op has to be an add op.
676 if (binaryExpr.getKind() != AffineExprKind::Add)
677 return failure();
678 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
679 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
680 if (failed(lhsDimPos) || failed(rhsDimPos))
681 return failure();
// Record the pairing between the two dims of a `d_i + d_j` access in both
// directions.
682 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
683 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
684 return success();
685 }
686
687 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
688 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
689 int64_t dim = dimExpr.getPosition();
690 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
691 return failure();
692 // Stride/dilation for this dim is implicitly 1.
693 strideAndDilationMapping[dim] =
// NOTE(review): the right-hand side of this assignment (line 694) is missing
// from this view — presumably a constant-1 affine expression; confirm
// upstream.
695 convolvedDims.insert(dim);
696 return dim;
697 }
698 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
699 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
700 return failure();
701 auto lhsExpr = symbolMulExpr.getLHS();
702 auto rhsExpr = symbolMulExpr.getRHS();
703 // Check for symbol expression.
704 AffineExpr mulExpr =
// NOTE(review): the initializer (line 705) is missing from this view —
// presumably getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
// confirm upstream.
706 // If there was no symbol expr, check for constant expression.
707 if (!mulExpr) {
708 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
709 }
710 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
711 if (!mulExpr || !dimExpr)
712 return failure();
713 int64_t dim = dimExpr.getPosition();
714 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
715 return failure();
716 strideAndDilationMapping[dim] = mulExpr;
717 convolvedDims.insert(dim);
718 return dim;
719 }
720 return failure();
721 }
722};
723} // namespace
724
725static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
726 assert(map.isProjectedPermutation() &&
727 "expected map to have projected permutations");
728 llvm::SmallDenseSet<int64_t> preservedDims;
729 for (auto expr : map.getResults())
730 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
731 return preservedDims;
732}
733
// Converts a list of affine expressions (each asserted to be a constant)
// into their integer values.
// NOTE(review): the signature line and the declaration of `vals` are missing
// from this view (scrape truncation); the surviving lines are preserved
// verbatim.
737 for (auto e : exprs) {
738 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
739 assert(constantExpr && "Found non-constant stride/dilation");
740 vals.push_back(constantExpr.getValue());
741 }
742 return vals;
743}
744
745/// Classifies dimensions in the `linalgOp` used by a convolution
746/// subcomputation, as captured by `inputExprWalker`. If
747/// `allowEmptyConvolvedDims` is not set this will fail if there is not
748/// at least one convolved dimension pair (output image + filter loop).
749///
750/// The returned dimensions are ordered as follows:
751/// - `outputImage` is sorted by dimension index.
752/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
753/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
754/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
755/// - `strides[i]` corresponds to `outputImage[i]`.
756/// - `dilations[i]` corresponds to `filterLoop[i]`.
757/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
758static FailureOr<ConvolutionDimensions>
759inferConvolutionDimsImpl(LinalgOp linalgOp,
760 ConvAccessExprWalker &inputExprWalker,
761 bool allowEmptyConvolvedDims) {
762 auto filterMap =
763 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
764 auto outputMap =
765 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
766 llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
767 filterMap, linalgOp.getIteratorTypesArray(), par);
768 llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
769 outputMap, linalgOp.getIteratorTypesArray(), par);
770
771 // unConvolvedDims & outputDims - filterDims are the batch iterators.
772 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
773 llvm::set_intersect(batch, outputDims);
774 llvm::set_subtract(batch, filterDims);
775
776 // convolvedDims & outputDims are the output image iterators.
777 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
778 llvm::set_intersect(oi, outputDims);
779
780 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
781 llvm::SmallDenseSet<int64_t> oc = filterDims;
782 llvm::set_intersect(oc, outputDims);
783 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
784
785 // filterDims & outputDims & unConvolvedDims are the depth iterators.
786 llvm::SmallDenseSet<int64_t> depth = filterDims;
787 llvm::set_intersect(depth, outputDims);
788 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
789
790 llvm::SmallDenseSet<int64_t> filterReducedDims =
// NOTE(review): the first line of this call (line 791) is missing from this
// view — presumably findPermutationsIndexingOperand(filterMap, ...); confirm
// upstream.
792 linalgOp.getIteratorTypesArray(), red);
793
794 // convolvedDims & filterReducedDims are the filter loop iterators.
795 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
796 llvm::set_intersect(fl, filterReducedDims);
797
798 // unConvolvedDims & filterReducedDims are the input channel iterators.
799 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
800 llvm::set_intersect(ic, filterReducedDims);
801
802 if (oi.empty() && !allowEmptyConvolvedDims)
803 return failure();
804
805 // Return each set in sorted order, with outputImage and filterLoop
806 // ordered so that outputImage[i] pairs with filterLoop[i].
807 ConvolutionDimensions dimensions{
808 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
809 SmallVector<unsigned, 2>(oi.begin(), oi.end()),
810 SmallVector<unsigned, 2>(oc.begin(), oc.end()),
811 /*filterLoop=*/SmallVector<unsigned, 2>{},
812 SmallVector<unsigned, 2>(ic.begin(), ic.end()),
813 SmallVector<unsigned, 2>(depth.begin(), depth.end()),
814 /*strides=*/SmallVector<int64_t, 2>{},
815 /*dilations=*/SmallVector<int64_t, 2>{}};
816 llvm::sort(dimensions.batch);
817 llvm::sort(dimensions.outputImage);
818 llvm::sort(dimensions.outputChannel);
819 llvm::sort(dimensions.inputChannel);
820 llvm::sort(dimensions.depth);
821 // Order filterLoop to match the pairing with outputImage. Each outputImage
822 // dimension has a corresponding filterLoop dimension from the convolution
823 // access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
824 // filterLoop[i].
825 for (unsigned oiDim : dimensions.outputImage)
826 dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
827
828 // Use the op carried strides/dilations attribute if present.
829 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
830 if (!nativeStrides) {
// No attribute: read the per-dim coefficients collected by the walker.
831 SmallVector<AffineExpr, 2> strideExprs;
832 for (unsigned oiDim : dimensions.outputImage)
833 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
834 dimensions.strides = getConstantsFromExprList(strideExprs);
835 } else {
836 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
837 }
838 auto nativeDilations =
839 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
840 if (!nativeDilations) {
841 SmallVector<AffineExpr, 2> dilationExprs;
842 for (unsigned flDim : dimensions.filterLoop)
843 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
844 dimensions.dilations = getConstantsFromExprList(dilationExprs);
845 } else {
846 dimensions.dilations =
847 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
848 }
849 return dimensions;
850}
851
/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed as the filter.
/// These dimensions are such that:
///   1. Optional batch dimensions that appear in the input and filter.
///   2. The output_image dimension is involved in a cross-correlation along LHS
///      (i.e. it is a permutation on RES and LHS and has an associated
///      filter_loop in RHS).
///   3. Optional output_channel dimension is involved in an outer-product along
///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
///      LHS).
///   4. Optional input_channel dimension appears as a permutation on LHS and
///      RHS.
///   5. The filter_loop dimension appears as a permutation on the RHS and
///      represents the shape of the kernel cross-correlated along a
///      corresponding output_image dim.
///   6. The input_channel dimension appears as a permutation on LHS and RHS.
///   7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
///
/// The `outputImage` and `filterLoop` arrays are ordered such that
/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
/// dimension 2). Other dimension sets are returned in sorted order.
///
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
880FailureOr<ConvolutionDimensions>
882 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
883 return failure();
884
885 auto indexingMaps = linalgOp.getIndexingMapsArray();
886
887 // Check the input indexing map has the right form.
888 ConvAccessExprWalker inputExprWalker;
889 for (AffineExpr expr : indexingMaps[0].getResults())
890 (void)inputExprWalker.visit(expr);
891 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
892
893 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
894 /*allowEmptyConvolvedDims=*/false);
895}
896
namespace mlir::linalg::detail {
/// Outcome of structurally matching an op against the convolution interface.
/// `Success` must stay 0; every other value identifies the first failed check
/// and maps to a user-facing message via getMatchConvolutionMessage.
/// NOTE(review): the enumerator list was reconstructed from the message switch
/// below — confirm against the forward declaration in LinalgInterfaces.h.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
910
913 Operation *op, ConvolutionDimensions *dimensions,
914 bool allowEmptyConvolvedDims) {
915 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
916 if (!linalgOp)
918 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
920
921 auto indexingMaps = linalgOp.getIndexingMapsArray();
922
923 // Check the input indexing map has the right form.
924 ConvAccessExprWalker inputExprWalker;
925 if (llvm::any_of(indexingMaps[0].getResults(),
926 [&inputExprWalker](AffineExpr expr) {
927 return failed(inputExprWalker.visit(expr));
928 })) {
930 }
931
932 // Filter and output maps must be projected permutation.
933 if (!indexingMaps[1].isProjectedPermutation() ||
934 !indexingMaps.back().isProjectedPermutation())
936
937 auto iteratorTypes = linalgOp.getIteratorTypesArray();
938
939 llvm::SmallDenseSet<int64_t> outputDims =
940 getPreservedDims(indexingMaps.back());
941 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
942 // Make sure all loops are characterized as one of:
943 // - Batch loop : present in output, as non-convolved in input, not present in
944 // filter.
945 // - Output image dimension : present in output, convolved dims in input, not
946 // present in filter.
947 // - Output channel dimension : present in output, not present in input,
948 // present in filter.
949 // - Filter loop dimension : present in filter, convolved in input, not
950 // present in output.
951 // - Input channel dimension : unconvolved in input, not present in output,
952 // present in filter.
953 // - Depth multiplier : unconvolved in input, present in output, present in
954 // filter.
955 llvm::SmallDenseSet<int64_t> allLoopDims;
956 for (auto outputExpr : indexingMaps.back().getResults()) {
957 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
958 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
959 !filterDims.count(outputDim)) {
960 // Batch dimension.
961 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
963 allLoopDims.insert(outputDim);
964 continue;
965 }
966 if (inputExprWalker.convolvedDims.count(outputDim) &&
967 !filterDims.count(outputDim)) {
968 // Output image Loop dimension.
969 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
971 allLoopDims.insert(outputDim);
972 continue;
973 }
974 if (!inputExprWalker.convolvedDims.count(outputDim) &&
975 !inputExprWalker.unConvolvedDims.count(outputDim) &&
976 filterDims.count(outputDim)) {
977 // Output channel dimension.
978 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
980 allLoopDims.insert(outputDim);
981 continue;
982 }
983 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
984 filterDims.count(outputDim)) {
985 // Depth multiplier.
986 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
988 allLoopDims.insert(outputDim);
989 continue;
990 }
992 }
993 for (auto filterExpr : indexingMaps[1].getResults()) {
994 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
995 if (outputDims.count(filterDim) &&
996 !inputExprWalker.unConvolvedDims.count(filterDim) &&
997 !inputExprWalker.convolvedDims.count(filterDim)) {
998 // Output channel dimension. This is already seen, continue;
999 continue;
1000 }
1001 if (inputExprWalker.convolvedDims.count(filterDim) &&
1002 !outputDims.count(filterDim)) {
1003 // Filter loop dimension.
1004 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1006 if (allLoopDims.count(filterDim))
1008 allLoopDims.insert(filterDim);
1009 continue;
1010 }
1011 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1012 !outputDims.count(filterDim)) {
1013 // Input channel dimension.
1014 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1016 if (allLoopDims.count(filterDim))
1018 allLoopDims.insert(filterDim);
1019 continue;
1020 }
1021 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1022 outputDims.count(filterDim)) {
1023 // Depthwise loop. Already seen.
1024 continue;
1025 }
1027 }
1028 // All loops must be covered now.
1029 if (allLoopDims.size() != linalgOp.getNumLoops())
1031
1032 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1034
1035 if (dimensions) {
1036 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1037 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1038 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1039 *dimensions = *res;
1040 }
1041
1043}
1044
1045StringRef
1047 switch (res) {
1049 return "expected a LinalgOp";
1051 return "expected op with 2 inputs and 1 output";
1053 return "unexpected input index map for convolutions";
1055 return "expected output/filter indexing maps to be projected permutations";
1057 return "unexpected loop dimension for convolution op";
1059 return "expected all iterators used to access outputs to be parallel";
1061 return "expected all iterators not used to access outputs to be reduction";
1063 return "expected convolved dim to be non-empty";
1065 return "";
1066 }
1067 llvm_unreachable("unhandled MatchConvolutionResult case");
1068}
1069
1071 bool allowEmptyConvolvedDims) {
1073 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1075}
1076
1083
1084//===----------------------------------------------------------------------===//
1085// FillOpInterface implementation
1086//===----------------------------------------------------------------------===//
1087
namespace {
/// Outcome of structurally matching an op against linalg.fill semantics;
/// `Success` stays 0, the other values identify the first failed check (see
/// isFillInterfaceImpl / verifyFillInterface).
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,      // Not a linalg::LinalgOp.
  WrongNumOperands, // Expected exactly 1 input and 1 init.
  NotScalarInput,   // The single input must be a scalar.
  TypeMismatch      // Scalar type differs from the init's element type.
};
} // namespace
1097
1098static MatchFillResult isFillInterfaceImpl(Operation *op) {
1099 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1100 if (!linalgOp)
1101 return MatchFillResult::NotLinalgOp;
1102 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1103 return MatchFillResult::WrongNumOperands;
1104
1105 OpOperand *value = linalgOp.getDpsInputOperand(0);
1106 if (!linalgOp.isScalar(value))
1107 return MatchFillResult::NotScalarInput;
1108
1109 // Check that the scalar input type matches the output element type.
1110 OpOperand *output = linalgOp.getDpsInitOperand(0);
1111 Type scalarType = value->get().getType();
1112 Type outputElementType = getElementTypeOrSelf(output->get().getType());
1113 if (scalarType != outputElementType)
1114 return MatchFillResult::TypeMismatch;
1115
1116 return MatchFillResult::Success;
1117}
1118
1120 MatchFillResult res = isFillInterfaceImpl(op);
1121 if (res == MatchFillResult::NotLinalgOp)
1122 return op->emitError("expected a LinalgOp");
1123 if (res == MatchFillResult::WrongNumOperands)
1124 return op->emitError("expected op with 1 input and 1 output");
1125 if (res == MatchFillResult::NotScalarInput)
1126 return op->emitError("expected op with scalar input");
1127 if (res == MatchFillResult::TypeMismatch) {
1128 auto linalgOp = cast<linalg::LinalgOp>(op);
1129 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1130 Type outputElementType =
1131 getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
1132 return op->emitOpError("expected fill value type (")
1133 << scalarType << ") to match output element type ("
1134 << outputElementType << ")";
1135 }
1136
1137 return success();
1138}
1139
1140//===----------------------------------------------------------------------===//
1141// StructuredOpInterface implementation
1142//===----------------------------------------------------------------------===//
1143
1144SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1145 Location loc) {
1147 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1148 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1149 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1150 }
1151 return res;
1152}
1153
1154SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1156 assert(!hasDynamicShape() && "expected operands to have static shapes");
1157 for (OpOperand &opOperand : getOperation()->getOpOperands())
1158 llvm::append_range(res, getShape(&opOperand));
1159 return res;
1160}
1161
1162SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1163 AffineMap map = getLoopsToShapesMap();
1164 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1165 auto viewSizes = createFlatListOfOperandDims(b, loc);
1166 SmallVector<Range, 4> res(numDims);
1167 for (unsigned idx = 0; idx < numRes; ++idx) {
1168 auto result = map.getResult(idx);
1169 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1170 if (res[d.getPosition()].offset)
1171 continue;
1172 res[d.getPosition()] =
1173 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1174 }
1175 }
1176 return res;
1177}
1178
1179/// Visitor to check if any of the given set of positions from AffineDimExprs
1180/// are used within an AffineExpr.
1182 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1183 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1184 : positions(std::move(positions)) {}
1185
1187 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1188 }
1189
1191 return positions.test(dimExpr.getPosition());
1192 }
1193
1194 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1195
1196 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1197
1198private:
1199 llvm::SmallBitVector positions;
1200};
1201
1202static std::pair<int64_t, int64_t>
1204 int64_t inputRankSum = 0;
1205 int64_t outputRankSum = 0;
1206 for (OpOperand *input : op.getDpsInputOperands())
1207 inputRankSum += op.getRank(input);
1208 for (OpOperand &output : op.getDpsInitsMutable())
1209 outputRankSum += op.getRank(&output);
1210 return {inputRankSum, inputRankSum + outputRankSum};
1211}
1212
1213LogicalResult
1214LinalgOp::reifyResultShapes(OpBuilder &b,
1215 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1216 // An example that helps understand the logic below.
1217 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1218 // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1219 // This is achieved as follows.
1220 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1221 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1222 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1223 // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1224 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1225 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1226
1227 // Find the position in the above map that represents the shape of the
1228 // result:dim being inferred.
1229 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1230
1231 /// From loopsToShapesMap extract the submap that represents the shape of the
1232 /// (resultIdx, dim) needed.
1233 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1234 resultShapesSubMapPos.first,
1235 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1236 AffineMap resultShapesFromInputShapesMap =
1237 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1238
1239 // Check that the result dim map does not contain the positions corresponding
1240 // to the outputs.
1241 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1242 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1243 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1244 Location loc = getOperation()->getLoc();
1245 IRRewriter rewriter(b);
1246 SmallVector<OpFoldResult> allResultDimValues =
1248 rewriter, loc, resultShapesFromInputShapesMap,
1249 createFlatListOfOperandDims(b, loc));
1250 int64_t pos = 0;
1251 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1252 for (OpOperand &opOperand : getDpsInitsMutable()) {
1253 SmallVector<OpFoldResult> shapes;
1254 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1255 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1256 if (!shapedType.isDynamicDim(dim)) {
1257 // Static dim: Return IntegerAttr.
1258 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1259 } else {
1260 // Dynamic dim: Return Value.
1261 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1262 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1263 : allResultDimValues[pos];
1264 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1265 }
1266 pos++;
1267 }
1268 reifiedReturnShapes.emplace_back(std::move(shapes));
1269 }
1270 return success();
1271}
1272
1273/// Return the index in the indexingMaps vector that corresponds to this
1274/// `opOperand`.
1275int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1276 auto operandNumber = opOperand->getOperandNumber();
1277 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1278 if (!dpsIface.isDpsInput(opOperand))
1279 return operandNumber;
1280 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1281 assert(!dpsIface.isDpsInit(opOperand));
1282 // Account for potential inputs that are not DPS and may not appear in
1283 // `indexingMaps`.
1284 return cast<DestinationStyleOpInterface>(*this->getOperation())
1285 .getNumDpsInputs() +
1286 operandNumber - start;
1287}
1288
1290 LinalgOp linalgOp = cast<LinalgOp>(op);
1291 // Mixed tensor/buffer operands are not allowed.
1292 if (!linalgOp.hasPureTensorSemantics() &&
1293 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1294 return op->emitOpError("expected to have pure tensor or buffer semantics");
1295
1296 // Before checking indexing maps, we need to make sure the attributes
1297 // referenced by it are valid.
1298 if (linalgOp.hasDynamicIndexingMaps())
1299 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1300 return failure();
1301
1302 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1303 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1304 return failure();
1305
1306 // Set this flag if this op has user defined maps. This is required to guard
1307 // the below error condition which assume default indexing maps.
1308 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1309 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1310 // Domain must be consistent.
1311 unsigned numLoops = linalgOp.getNumLoops();
1312 if (indexingMap.getNumDims() != numLoops)
1313 return op->emitOpError("expected indexing_map #")
1314 << opOperand.getOperandNumber() << " to have " << numLoops
1315 << " dim(s) to match the number of loops";
1316 }
1317 SmallVector<unsigned> redDims;
1318 linalgOp.getReductionDims(redDims);
1319
1320 if (!linalgOp.getShapesToLoopsMap())
1321 return op->emitOpError("expected the shape-to-loops map to be non-null");
1322
1323 // Check the region has exactly one block.
1324 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1325 return op->emitOpError("expects to have 1 region with 1 block");
1326
1327 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1328 // types.
1329 // TODO: once ranked shape types are plugged in, we may want to drop the
1330 // corresponding bbargs, that can never be read from. This will be subject to
1331 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1332 // not used).
1333 Block &block = linalgOp->getRegion(0).front();
1334
1335 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1336 return op->emitOpError("expected as many non-induction variable region "
1337 "arguments as the number of input/output operands");
1338
1339 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1340 Type elementType = opOperand->get().getType();
1341 if (isa<MemRefType, RankedTensorType>(elementType))
1342 elementType = getElementTypeOrSelf(opOperand->get().getType());
1343 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1344 if (elementType != argType)
1345 return op->emitOpError("expected type of bb argument #")
1346 << opOperand->getOperandNumber() << " (" << argType << ")"
1347 << " to match element or self type of the corresponding operand ("
1348 << elementType << ")";
1349 }
1350
1351 return success();
1352}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:786
unsigned getNumOperands()
Definition Operation.h:375
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop