MLIR 23.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/MLIRContext.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
// NOTE(review): the opening signature line of this definition is missing from
// this excerpt (presumably `bool detail::canOpOperandsBeDroppedImpl(`) —
// confirm against the full file.
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  // Gather the indexing maps of every operand that is NOT being dropped.
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  // NOTE(review): the first line of this return expression is missing from
  // the excerpt; the visible tail compares an inverted/concatenated map
  // against the null AffineMap, i.e. the remaining maps must still cover all
  // loops for the dropped operands to be removable.
             indexingMaps, linalgOp.getContext())) != AffineMap();
}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
// NOTE(review): the `linalg::isaBroadcastOpInterface(LinalgOp linalgOp) {`
// signature line is missing from this excerpt — confirm against the full file.
  // A named linalg.broadcast directly carries its broadcast dimensions.
  if (auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(broadcastOp.getDimensions().begin(),
                                broadcastOp.getDimensions().end());

  // Otherwise only a linalg.generic can match structurally.
  auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!op)
    return std::nullopt;

  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  // Both source and destination must be shaped (memref or ranked tensor).
  auto srcTy = op.getDpsInputOperand(0)->get().getType();
  auto dstTy = op.getDpsInitOperand(0)->get().getType();
  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;

  // Check output is identity map. Broadcast could additionally be
  // employing permutation of indices and that would be expressible
  // in linalg.generic but is not expressible for named broadcast op.
  auto dstMap = op.getIndexingMapsArray()[1];
  if (!dstMap.isIdentity())
    return std::nullopt;

  SmallVector<int64_t> position;
  auto srcMap = op.getIndexingMapsArray()[0];

  // A broadcast must add at least one dimension.
  if (srcMap.getResults().size() >= dstMap.getResults().size())
    return std::nullopt;

  // Check input map is monotonically increasing DimIds.
  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    if (i > 0 && pos <= position[i - 1])
      return std::nullopt;
    position.push_back(expr.getPosition());
  }

  // The broadcasted dims are exactly the loop dims absent from the input map.
  SmallVector<int64_t> broadcastedDims;
  auto numDims = srcMap.getNumDims();
  // This is quadratic but number of items is generally small.
  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
    if (!llvm::is_contained(position, dim))
      broadcastedDims.push_back(dim);
  }
  return broadcastedDims;
}
194
195//===----------------------------------------------------------------------===//
196// TransposeOpInterface implementation
197//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
// NOTE(review): the `linalg::isaTransposeOpInterface(GenericOp op) {`
// signature line is missing from this excerpt — confirm against the full file.
  // To specialize as a transpose op, the genericOp must be
  // all parallel loops, single input, single output, and its body
  // should be just a yield op, yielding input as output as is (no compute).
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2)
    return std::nullopt;

  auto mapOfInput = mapRange.front();
  auto mapOfResult = mapRange.back();

  // linalg.transpose permutes the dimensions of input using this
  // rule: dim(result, i) = dim(input, permutation[i])
  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
    return std::nullopt;

  // Invert the input permutation map into the `permutation` vector.
  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
    permutation[expr.getPosition()] = i;
  }
  return permutation;
}
226
227//===----------------------------------------------------------------------===//
228// Elementwise Single Unary/Binary-OpInterface implementation
229//===----------------------------------------------------------------------===//
230static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
231 unsigned arity) {
232 // Check all loops are parallel.
233 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
234 return false;
235
236 // Check there are arity-inputs, 1-output and all are identity-maps.
237 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
238 !llvm::all_of(op.getIndexingMapsArray(),
239 [](AffineMap map) { return map.isIdentity(); }))
240 return false;
241
242 // Init should not be referenced for elementwise operations.
243 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
244 return false;
245
246 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
247 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
248 // the body, where the first is the elementwise single op and the second a
249 // yield.
250 Block *body = op.getBody();
251 if (body->getOperations().size() != 2)
252 return false;
253
254 Operation *oper = &body->front();
255 if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
256 return false;
257
258 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
259 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
260 yieldOp->getOperand(0).getDefiningOp() != oper);
261}
262
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  // NOTE(review): the guard condition line (presumably
  // `if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))`) is missing from
  // this excerpt — the `return false;` below belongs to it.
    return false;

  // Check input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}
273
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
  // NOTE(review): the basic-structure guard line (presumably
  // `if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))`) is missing from
  // this excerpt — the `return false;` below belongs to it.
    return false;

  // Check both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
  return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
           !op.payloadUsesValueFromOperand(inputOpOperand1));
}
284
285//===----------------------------------------------------------------------===//
286// ContractionOpInterface implementation
287//===----------------------------------------------------------------------===//
288
/// If the value is defined by a chain of unary side effect-free, go up the
/// use-def chain until the first value that isn't defined by such an op.
// TODO: relax to multi-operands with constants, which are technically unary ops
// as needed (e.g. add5).
// NOTE(review): the `static Value getSourceSkipUnary(Value value) {` signature
// line is missing from this excerpt — confirm against the full file.
  Operation *op = value.getDefiningOp();
  // Walk producers as long as they are single-operand ops with no memory
  // effects (pure unary casts and the like).
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}
304
// NOTE(review): the signature line of this definition (presumably
// `bool mlir::linalg::detail::isContractionBody(`) is missing from this
// excerpt — confirm against the full file.
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  // The block must be properly terminated to be inspectable at all.
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  // A contraction body carries (lhs, rhs, acc) as block arguments.
  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  // Peel off any pure unary ops between the yield and the reduction op.
  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (!reductionOp || reductionOp->getNumResults() != 1 ||
      reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  // One side of the reduction must be the accumulator (block argument #2).
  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  // The other side carries the elementwise contribution (e.g. the multiply).
  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  // The (elementwise, reduction) pair must match one of the allowed kinds.
  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  // The elementwise op must combine block arguments #0 and #1, in either
  // order, modulo unary casts.
  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}
369
/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
// NOTE(review): the `static bool isPairTemplateImpl(Operation *add,
// Operation *mul) {` signature line and the recursive call inside the
// `if constexpr` branch are missing from this excerpt — confirm against the
// full file.
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Try the remaining template-argument pairs, if any.
  if constexpr (sizeof...(Args) > 0)
  else
    return false;
}
384
385/// Returns true if the block is a body of a contraction with the kinds of
386/// operations given pairwise by template arguments.
387template <typename... Args>
391
392/// Given an `indexingMap` and its corresponding `iterators`, returns
393/// the positions of the iterators of type `iter` that are indexed by
394/// the `indexingMap` as a permutation. This is useful to infer various
395/// subcomputations on a `LinalgOp`. This is performed by looking up
396/// each result in the `indexingMap` and determining whether:
397/// - It is a single AffineDimExpr.
398/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
// NOTE(review): the lines naming this helper and its first parameters
// (presumably `findPermutationsIndexingOperand(AffineMap indexingMap,
// ArrayRef<utils::IteratorType> iterators,`) are missing from this excerpt.
    utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      // Keep `d` only when it has the requested iterator type and is the
      // unique result of `indexingMap` involving that dimension.
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
416
namespace {
// Short aliases for the two iterator kinds used by the inference helpers
// below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
421
422/// Infer the iterator types from the init affine map. This looks at which dims
423/// are present in the map results, and returns an iterator types array with
424/// parallel types for dims that are present, and reduction types for dims that
425/// are not present.
static FailureOr<SmallVector<utils::IteratorType>>
// NOTE(review): the `inferIteratorsFromOutMap(AffineMap map) {` signature line
// is missing from this excerpt — confirm against the full file.
  if (!map.isProjectedPermutation())
    return failure();
  // Default every loop to reduction, then promote to parallel any dim that
  // appears in the output map's results.
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}
436
437/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
438/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
439/// 1. The m dimension is involved in an outer-product along LHS
440/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
441/// 2. The n dimension is involved in an outer-product along RHS
442/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
443/// 3. The k dimension appears as a permutation on LHS and RHS.
444/// 4. m, n and k appear only once in any given indexing.
445/// 5. Optional batch dimensions that appear in all operands are captured.
446/// This allows e.g. detecting that some contraction is embedded within
447/// `linalgOp` with some orthogonal heuristic.
static FailureOr<ContractionDimensions>
// NOTE(review): the signature lines of this definition (presumably
// `inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
// ArrayRef<utils::IteratorType> iterators) {`) are missing from this excerpt.
  // Parallel dims uniquely indexed by each of LHS (a), RHS (b) and RES (c).
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.m);
  llvm::sort(dimensions.n);
  llvm::sort(dimensions.k);
  return dimensions;
}
490
FailureOr<ContractionDimensions>
// NOTE(review): the `mlir::linalg::inferContractionDims(LinalgOp linalgOp) {`
// signature line is missing from this excerpt — confirm against the full file.
  // Contractions have exactly two inputs and one output.
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}
498
FailureOr<ContractionDimensions>
// NOTE(review): the signature line of this overload (presumably
// `mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {`)
// is missing from this excerpt — confirm against the full file.
  if (indexingMaps.size() != 3)
    return failure();
  // Derive iterator types from the output map: dims present in its results
  // are parallel, all remaining dims are reductions.
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}
508
509namespace mlir::linalg::detail {
518} // namespace mlir::linalg::detail
519
// NOTE(review): this excerpt of `isContractionInterfaceImpl` is heavily
// garbled — the signature lines and the `return MatchContractionResult::...`
// statements following each guard below are missing. Each bare `if` visibly
// lacks its body; confirm against the full file before relying on comments
// here.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
  // TODO: more fields than add/mul.
  // clang-format off
      arith::MulFOp, arith::AddFOp,
      arith::MulIOp, arith::AddIOp,
      complex::MulOp, complex::AddOp,
      arith::AndIOp, arith::OrIOp>(
          *linalgOp.getBlock())) {
  }
  // clang-format on

  // On success, optionally report the inferred contraction dimensions.
  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
}
553
StringRef
// NOTE(review): the signature line (taking a `MatchContractionResult res`)
// and the `case` labels of this switch are missing from this excerpt; each
// return below maps one enumerator to its diagnostic, ending with the empty
// string for Success. Confirm against the full file.
  switch (res) {
    return "expected a LinalgOp";
    return "expected op with 2 inputs and 1 output";
    return "expected at least 1 reduction";
    return "expected indexing maps to be projected permutations";
    return "expected add/mul op in the body";
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}
572
// NOTE(review): the signature line (presumably
// `bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {`) and
// the second operand of the trailing `||` (presumably the comparison of
// `isContractionInterfaceImpl(op)` against Success) are missing from this
// excerpt — confirm against the full file.
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  // Either the op implements the interface directly, or it structurally
  // matches a contraction.
  return isa<ContractionOpInterface>(op) ||
}
581
582/// Verify that a LinalgOp `op` is a contraction.
583/// A Linalg contraction is defined in general terms:
584/// 1. Has 2 input and 1 output shapes.
585/// 2. Has at least one reduction dimension.
586/// 3. Has only projected permutation indexing maps.
587/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
588/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
589/// operations that may change the type (e.g. for mixed-precision).
590/// As a consequence, when vectorization of such an op occurs, the only special
591/// behavior is that the (unique) MulOpType is vectorized into a
592/// `vector.contract`. All other ops are handled in a generic fashion.
593/// In the future, we may wish to allow more input arguments and elementwise and
594/// constant operations that do not involve the reduction dimension(s).
601
602//===----------------------------------------------------------------------===//
603// ConvolutionOpInterface implementation
604//===----------------------------------------------------------------------===//
605
606/// Of the given two expressions returns one that is of type T (`lhs` gets
607/// preference over `rhs`)
608template <typename T>
610 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
611}
612
namespace {
/// Walk the indexing expressions for input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExpr as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      // A dim used by more than one result of the map is ambiguous; drop it.
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        auto it = convolvedDimMapping.find(dimPos);
        if (it != convolvedDimMapping.end()) {
          int64_t pairedDim = it->second;
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  // A bare dim expression is an unconvolved dimension; it may appear once.
  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  // Top-level symbols/constants are not valid convolution access exprs.
  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    // Record the two dims of `lhs + rhs` as convolution partners of each
    // other (e.g. `oh` and `kh` in `oh + kh`).
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      // NOTE(review): the RHS of this assignment (presumably
      // `getAffineConstantExpr(1, expr.getContext())`) is missing from this
      // excerpt.
      strideAndDilationMapping[dim] =
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      // NOTE(review): the RHS of this initialization (presumably
      // `getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr)`) is missing
      // from this excerpt.
      AffineExpr mulExpr =
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace
720
721static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
722 assert(map.isProjectedPermutation() &&
723 "expected map to have projected permutations");
724 llvm::SmallDenseSet<int64_t> preservedDims;
725 for (auto expr : map.getResults())
726 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
727 return preservedDims;
728}
729
733 for (auto e : exprs) {
734 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
735 assert(constantExpr && "Found non-constant stride/dilation");
736 vals.push_back(constantExpr.getValue());
737 }
738 return vals;
739}
740
741/// Classifies dimensions in the `linalgOp` used by a convolution
742/// subcomputation, as captured by `inputExprWalker`. If
743/// `allowEmptyConvolvedDims` is not set this will fail if there is not
744/// at least one convolved dimension pair (output image + filter loop).
745///
746/// The returned dimensions are ordered as follows:
747/// - `outputImage` is sorted by dimension index.
748/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
749/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
750/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
751/// - `strides[i]` corresponds to `outputImage[i]`.
752/// - `dilations[i]` corresponds to `filterLoop[i]`.
753/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  // Parallel dims uniquely indexed by the filter and the output respectively.
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  // NOTE(review): the first line of this initializer (presumably
  // `findPermutationsIndexingOperand(filterMap,`) is missing from this
  // excerpt — confirm against the full file.
  llvm::SmallDenseSet<int64_t> filterReducedDims =
      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  // A convolution needs at least one output-image/filter-loop pair unless the
  // caller explicitly allows the degenerate form.
  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order, with outputImage and filterLoop
  // ordered so that outputImage[i] pairs with filterLoop[i].
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      /*filterLoop=*/SmallVector<unsigned, 2>{},
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.outputImage);
  llvm::sort(dimensions.outputChannel);
  llvm::sort(dimensions.inputChannel);
  llvm::sort(dimensions.depth);
  // Order filterLoop to match the pairing with outputImage. Each outputImage
  // dimension has a corresponding filterLoop dimension from the convolution
  // access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
  // filterLoop[i].
  for (unsigned oiDim : dimensions.outputImage)
    dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}
847
848/// Find at least 1 parallel (output_image) and reduction (filter_loop)
849/// dimension candidates that form a convolution subcomputation within
850/// `linalgOp`. The LHS is assumed to be the convolution input while the
851/// RHS is assumed as the filter.
852/// These dimensions are such that:
853/// 1. Optional batch dimensions that appear in the input and filter.
854/// 2. The output_image dimension is involved in a cross-correlation along LHS
855/// (i.e. it is a permutation on RES and LHS and has an associated
856/// filter_loop in RHS).
857/// 3. Optional output_channel dimension is involved in an outer-product along
858/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
859/// LHS).
860/// 4. Optional input_channel dimension appears as a permutation on LHS and
861/// RHS.
862/// 5. The filter_loop dimension appears as a permutation on the RHS and
863/// represents the shape of the kernel cross-correlated along a
864/// corresponding output_image dim.
865/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
866/// 7. All dimensions appear only once in any given indexing map.
867/// This allows e.g. detecting that some convolution is embedded within
868/// `linalgOp` with some orthogonal heuristic.
869///
870/// The `outputImage` and `filterLoop` arrays are ordered such that
871/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
872/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
873/// dimension 2). Other dimension sets are returned in sorted order.
874///
875/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
876FailureOr<ConvolutionDimensions>
878 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
879 return failure();
880
881 auto indexingMaps = linalgOp.getIndexingMapsArray();
882
883 // Check the input indexing map has the right form.
884 ConvAccessExprWalker inputExprWalker;
885 for (AffineExpr expr : indexingMaps[0].getResults())
886 (void)inputExprWalker.visit(expr);
887 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
888
889 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
890 /*allowEmptyConvolvedDims=*/false);
891}
892
namespace mlir::linalg::detail {
/// Result codes produced by `isConvolutionInterfaceImpl`; translated to
/// human-readable diagnostics by `getMatchConvolutionMessage`.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
906
909 Operation *op, ConvolutionDimensions *dimensions,
910 bool allowEmptyConvolvedDims) {
911 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
912 if (!linalgOp)
914 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
916
917 auto indexingMaps = linalgOp.getIndexingMapsArray();
918
919 // Check the input indexing map has the right form.
920 ConvAccessExprWalker inputExprWalker;
921 if (llvm::any_of(indexingMaps[0].getResults(),
922 [&inputExprWalker](AffineExpr expr) {
923 return failed(inputExprWalker.visit(expr));
924 })) {
926 }
927
928 // Filter and output maps must be projected permutation.
929 if (!indexingMaps[1].isProjectedPermutation() ||
930 !indexingMaps.back().isProjectedPermutation())
932
933 auto iteratorTypes = linalgOp.getIteratorTypesArray();
934
935 llvm::SmallDenseSet<int64_t> outputDims =
936 getPreservedDims(indexingMaps.back());
937 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
938 // Make sure all loops are characterized as one of:
939 // - Batch loop : present in output, as non-convolved in input, not present in
940 // filter.
941 // - Output image dimension : present in output, convolved dims in input, not
942 // present in filter.
943 // - Output channel dimension : present in output, not present in input,
944 // present in filter.
945 // - Filter loop dimension : present in filter, convolved in input, not
946 // present in output.
947 // - Input channel dimension : unconvolved in input, not present in output,
948 // present in filter.
949 // - Depth multiplier : unconvolved in input, present in output, present in
950 // filter.
951 llvm::SmallDenseSet<int64_t> allLoopDims;
952 for (auto outputExpr : indexingMaps.back().getResults()) {
953 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
954 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
955 !filterDims.count(outputDim)) {
956 // Batch dimension.
957 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
959 allLoopDims.insert(outputDim);
960 continue;
961 }
962 if (inputExprWalker.convolvedDims.count(outputDim) &&
963 !filterDims.count(outputDim)) {
964 // Output image Loop dimension.
965 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
967 allLoopDims.insert(outputDim);
968 continue;
969 }
970 if (!inputExprWalker.convolvedDims.count(outputDim) &&
971 !inputExprWalker.unConvolvedDims.count(outputDim) &&
972 filterDims.count(outputDim)) {
973 // Output channel dimension.
974 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
976 allLoopDims.insert(outputDim);
977 continue;
978 }
979 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
980 filterDims.count(outputDim)) {
981 // Depth multiplier.
982 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
984 allLoopDims.insert(outputDim);
985 continue;
986 }
988 }
989 for (auto filterExpr : indexingMaps[1].getResults()) {
990 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
991 if (outputDims.count(filterDim) &&
992 !inputExprWalker.unConvolvedDims.count(filterDim) &&
993 !inputExprWalker.convolvedDims.count(filterDim)) {
994 // Output channel dimension. This is already seen, continue;
995 continue;
996 }
997 if (inputExprWalker.convolvedDims.count(filterDim) &&
998 !outputDims.count(filterDim)) {
999 // Filter loop dimension.
1000 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1002 if (allLoopDims.count(filterDim))
1004 allLoopDims.insert(filterDim);
1005 continue;
1006 }
1007 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1008 !outputDims.count(filterDim)) {
1009 // Input channel dimension.
1010 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1012 if (allLoopDims.count(filterDim))
1014 allLoopDims.insert(filterDim);
1015 continue;
1016 }
1017 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1018 outputDims.count(filterDim)) {
1019 // Depthwise loop. Already seen.
1020 continue;
1021 }
1023 }
1024 // All loops must be covered now.
1025 if (allLoopDims.size() != linalgOp.getNumLoops())
1027
1028 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1030
1031 if (dimensions) {
1032 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1033 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1034 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1035 *dimensions = *res;
1036 }
1037
1039}
1040
1041StringRef
1043 switch (res) {
1045 return "expected a LinalgOp";
1047 return "expected op with 2 inputs and 1 output";
1049 return "unexpected input index map for convolutions";
1051 return "expected output/filter indexing maps to be projected permutations";
1053 return "unexpected loop dimension for convolution op";
1055 return "expected all iterators used to access outputs to be parallel";
1057 return "expected all iterators not used to access outputs to be reduction";
1059 return "expected convolved dim to be non-empty";
1061 return "";
1062 }
1063 llvm_unreachable("unhandled MatchConvolutionResult case");
1064}
1065
1067 bool allowEmptyConvolvedDims) {
1069 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1071}
1072
1079
1080//===----------------------------------------------------------------------===//
1081// FillOpInterface implementation
1082//===----------------------------------------------------------------------===//
1083
namespace {
/// Result codes for `isFillInterfaceImpl`; translated to diagnostics by
/// `verifyFillInterface`.
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,      // Op does not implement LinalgOp.
  WrongNumOperands, // Not exactly 1 input and 1 init.
  NotScalarInput,   // The single input operand is not a scalar.
  TypeMismatch      // Fill value type differs from the output element type.
};
} // namespace
1093
1094static MatchFillResult isFillInterfaceImpl(Operation *op) {
1095 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1096 if (!linalgOp)
1097 return MatchFillResult::NotLinalgOp;
1098 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1099 return MatchFillResult::WrongNumOperands;
1100
1101 OpOperand *value = linalgOp.getDpsInputOperand(0);
1102 if (!linalgOp.isScalar(value))
1103 return MatchFillResult::NotScalarInput;
1104
1105 // Check that the scalar input type matches the output element type.
1106 OpOperand *output = linalgOp.getDpsInitOperand(0);
1107 Type scalarType = value->get().getType();
1108 Type outputElementType = getElementTypeOrSelf(output->get().getType());
1109 if (scalarType != outputElementType)
1110 return MatchFillResult::TypeMismatch;
1111
1112 return MatchFillResult::Success;
1113}
1114
1116 MatchFillResult res = isFillInterfaceImpl(op);
1117 if (res == MatchFillResult::NotLinalgOp)
1118 return op->emitError("expected a LinalgOp");
1119 if (res == MatchFillResult::WrongNumOperands)
1120 return op->emitError("expected op with 1 input and 1 output");
1121 if (res == MatchFillResult::NotScalarInput)
1122 return op->emitError("expected op with scalar input");
1123 if (res == MatchFillResult::TypeMismatch) {
1124 auto linalgOp = cast<linalg::LinalgOp>(op);
1125 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1126 Type outputElementType =
1127 getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
1128 return op->emitOpError("expected fill value type (")
1129 << scalarType << ") to match output element type ("
1130 << outputElementType << ")";
1131 }
1132
1133 return success();
1134}
1135
1136//===----------------------------------------------------------------------===//
1137// StructuredOpInterface implementation
1138//===----------------------------------------------------------------------===//
1139
1140SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1141 Location loc) {
1143 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1144 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1145 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1146 }
1147 return res;
1148}
1149
1150SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1152 assert(!hasDynamicShape() && "expected operands to have static shapes");
1153 for (OpOperand &opOperand : getOperation()->getOpOperands())
1154 llvm::append_range(res, getShape(&opOperand));
1155 return res;
1156}
1157
1158SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1159 AffineMap map = getLoopsToShapesMap();
1160 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1161 auto viewSizes = createFlatListOfOperandDims(b, loc);
1162 SmallVector<Range, 4> res(numDims);
1163 for (unsigned idx = 0; idx < numRes; ++idx) {
1164 auto result = map.getResult(idx);
1165 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1166 if (res[d.getPosition()].offset)
1167 continue;
1168 res[d.getPosition()] =
1169 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1170 }
1171 }
1172 return res;
1173}
1174
1175/// Visitor to check if any of the given set of positions from AffineDimExprs
1176/// are used within an AffineExpr.
1178 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1179 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1180 : positions(std::move(positions)) {}
1181
1183 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1184 }
1185
1187 return positions.test(dimExpr.getPosition());
1188 }
1189
1190 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1191
1192 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1193
1194private:
1195 llvm::SmallBitVector positions;
1196};
1197
1198static std::pair<int64_t, int64_t>
1200 int64_t inputRankSum = 0;
1201 int64_t outputRankSum = 0;
1202 for (OpOperand *input : op.getDpsInputOperands())
1203 inputRankSum += op.getRank(input);
1204 for (OpOperand &output : op.getDpsInitsMutable())
1205 outputRankSum += op.getRank(&output);
1206 return {inputRankSum, inputRankSum + outputRankSum};
1207}
1208
1209LogicalResult
1210LinalgOp::reifyResultShapes(OpBuilder &b,
1211 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1212 // An example that helps understand the logic below.
1213 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1214 // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1215 // This is achieved as follows.
1216 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1217 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1218 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1219 // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1220 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1221 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1222
1223 // Find the position in the above map that represents the shape of the
1224 // result:dim being inferred.
1225 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1226
1227 /// From loopsToShapesMap extract the submap that represents the shape of the
1228 /// (resultIdx, dim) needed.
1229 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1230 resultShapesSubMapPos.first,
1231 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1232 AffineMap resultShapesFromInputShapesMap =
1233 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1234
1235 // Check that the result dim map does not contain the positions corresponding
1236 // to the outputs.
1237 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1238 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1239 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1240 Location loc = getOperation()->getLoc();
1241 IRRewriter rewriter(b);
1242 SmallVector<OpFoldResult> allResultDimValues =
1243 affine::makeComposedFoldedMultiResultAffineApply(
1244 rewriter, loc, resultShapesFromInputShapesMap,
1245 createFlatListOfOperandDims(b, loc));
1246 int64_t pos = 0;
1247 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1248 for (OpOperand &opOperand : getDpsInitsMutable()) {
1249 SmallVector<OpFoldResult> shapes;
1250 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1251 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1252 if (!shapedType.isDynamicDim(dim)) {
1253 // Static dim: Return IntegerAttr.
1254 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1255 } else {
1256 // Dynamic dim: Return Value.
1257 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1258 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1259 : allResultDimValues[pos];
1260 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1261 }
1262 pos++;
1263 }
1264 reifiedReturnShapes.emplace_back(std::move(shapes));
1265 }
1266 return success();
1267}
1268
1269/// Return the index in the indexingMaps vector that corresponds to this
1270/// `opOperand`.
1271int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1272 auto operandNumber = opOperand->getOperandNumber();
1273 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1274 if (!dpsIface.isDpsInput(opOperand))
1275 return operandNumber;
1276 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1277 assert(!dpsIface.isDpsInit(opOperand));
1278 // Account for potential inputs that are not DPS and may not appear in
1279 // `indexingMaps`.
1280 return cast<DestinationStyleOpInterface>(*this->getOperation())
1281 .getNumDpsInputs() +
1282 operandNumber - start;
1283}
1284
1286 LinalgOp linalgOp = cast<LinalgOp>(op);
1287 // Mixed tensor/buffer operands are not allowed.
1288 if (!linalgOp.hasPureTensorSemantics() &&
1289 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1290 return op->emitOpError("expected to have pure tensor or buffer semantics");
1291
1292 // Before checking indexing maps, we need to make sure the attributes
1293 // referenced by it are valid.
1294 if (linalgOp.hasDynamicIndexingMaps())
1295 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1296 return failure();
1297
1298 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1299 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1300 return failure();
1301
1302 // Set this flag if this op has user defined maps. This is required to guard
1303 // the below error condition which assume default indexing maps.
1304 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1305 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1306 // Domain must be consistent.
1307 unsigned numLoops = linalgOp.getNumLoops();
1308 if (indexingMap.getNumDims() != numLoops)
1309 return op->emitOpError("expected indexing_map #")
1310 << opOperand.getOperandNumber() << " to have " << numLoops
1311 << " dim(s) to match the number of loops";
1312 }
1313 SmallVector<unsigned> redDims;
1314 linalgOp.getReductionDims(redDims);
1315
1316 if (!linalgOp.getShapesToLoopsMap())
1317 return op->emitOpError("expected the shape-to-loops map to be non-null");
1318
1319 // Check the region has exactly one block.
1320 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1321 return op->emitOpError("expects to have 1 region with 1 block");
1322
1323 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1324 // types.
1325 // TODO: once ranked shape types are plugged in, we may want to drop the
1326 // corresponding bbargs, that can never be read from. This will be subject to
1327 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1328 // not used).
1329 Block &block = linalgOp->getRegion(0).front();
1330
1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1332 return op->emitOpError("expected as many non-induction variable region "
1333 "arguments as the number of input/output operands");
1334
1335 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1336 Type elementType = opOperand->get().getType();
1337 if (isa<MemRefType, RankedTensorType>(elementType))
1338 elementType = getElementTypeOrSelf(opOperand->get().getType());
1339 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1340 if (elementType != argType)
1341 return op->emitOpError("expected type of bb argument #")
1342 << opOperand->getOperandNumber() << " (" << argType << ")"
1343 << " to match element or self type of the corresponding operand ("
1344 << elementType << ")";
1345 }
1346
1347 return success();
1348}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:786
unsigned getNumOperands()
Definition Operation.h:375
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop