// MLIR 22.0.0git — LinalgInterfaces.cpp
// (doxygen "Go to the documentation of this file" page artifact; not part of
// the original source)
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/MLIRContext.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
// NOTE(review): doxygen extraction dropped this function's declaration line;
// judging by the body it decides whether `droppedOperands` may be removed
// from `linalgOp` without losing loop coverage — confirm against upstream.
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  // Gather the indexing maps of every operand that is kept.
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  // The concatenated maps of the remaining operands must still be invertible,
  // i.e. every loop dimension must remain reachable.
  // NOTE(review): the head of this expression (inverse permutation of the
  // concatenated maps) was truncated by extraction.
      indexingMaps, linalgOp.getContext())) != AffineMap();
}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
// NOTE(review): extraction dropped the declaration line here; the body reads
// a GenericOp `op` and returns the broadcasted dimension positions — confirm
// against upstream.
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  // Source and destination must both be ranked (memref or tensor).
  auto srcTy = op.getDpsInputOperand(0)->get().getType();
  auto dstTy = op.getDpsInitOperand(0)->get().getType();
  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;

  // Check output is identity map. Broadcast could additionally be
  // employing permutation of indices and that would be expressible
  // in linalg.generic but is not expressible for named broadcast op.
  auto dstMap = op.getIndexingMapsArray()[1];
  if (!dstMap.isIdentity())
    return std::nullopt;

  SmallVector<int64_t> position;
  auto srcMap = op.getIndexingMapsArray()[0];

  // A broadcast must add at least one dimension to the source.
  if (srcMap.getResults().size() >= dstMap.getResults().size())
    return std::nullopt;

  // Check input map is monotonically increasing DimIds.
  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    if (i > 0 && pos <= position[i - 1])
      return std::nullopt;
    position.push_back(expr.getPosition());
  }

  // Every loop dimension not accessed by the source is a broadcasted dim.
  SmallVector<int64_t> broadcastedDims;
  auto numDims = srcMap.getNumDims();
  // This is quadratic but number of items is generally small.
  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
    if (!llvm::is_contained(position, dim))
      broadcastedDims.push_back(dim);
  }
  return broadcastedDims;
}
186
187//===----------------------------------------------------------------------===//
188// TransposeOpInterface implementation
189//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
// NOTE(review): extraction dropped the declaration line here; the body reads
// a GenericOp `op` and returns its permutation vector — confirm upstream.
  // To specialize as a transpose op, the genericOp must be
  // all parallel loops, single input, single output, and its body
  // should be just a yield op, yielding input as output as is (no compute).
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2)
    return std::nullopt;

  auto mapOfInput = mapRange.front();
  auto mapOfResult = mapRange.back();

  // linalg.transpose permutes the dimensions of input using this
  // rule: dim(result, i) = dim(input, permutation[i])
  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
    return std::nullopt;

  // Invert the input map: result position i reads input dim permutation[i].
  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
    permutation[expr.getPosition()] = i;
  }
  return permutation;
}
218
219//===----------------------------------------------------------------------===//
220// Elementwise Single Unary/Binary-OpInterface implementation
221//===----------------------------------------------------------------------===//
222static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
223 unsigned arity) {
224 // Check all loops are parallel.
225 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
226 return false;
227
228 // Check there are arity-inputs, 1-output and all are identity-maps.
229 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230 !llvm::all_of(op.getIndexingMapsArray(),
231 [](AffineMap map) { return map.isIdentity(); }))
232 return false;
233
234 // Init should not be referenced for elementwise operations.
235 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
236 return false;
237
238 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
239 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
240 // the body, where the first is the elementwise single op and the second a
241 // yield.
242 Block *body = op.getBody();
243 if (body->getOperations().size() != 2)
244 return false;
245
246 Operation *oper = &body->front();
247 if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
248 return false;
249
250 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
251 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
252 yieldOp->getOperand(0).getDefiningOp() != oper);
253}
254
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  // NOTE(review): extraction dropped the guard line here (presumably a call
  // to isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1)) — confirm upstream.
    return false;

  // Check input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}
265
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
  // NOTE(review): extraction dropped the guard line here (presumably a call
  // to isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2)) — confirm upstream.
    return false;

  // Check both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
  return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
           !op.payloadUsesValueFromOperand(inputOpOperand1));
}
276
277//===----------------------------------------------------------------------===//
278// ContractionOpInterface implementation
279//===----------------------------------------------------------------------===//
280
281/// If the value is defined by a chain of unary side effect-free, go up the
282/// use-def chain until the first value that isn't defined by such an op.
283// TODO: relax to multi-operands with constants, which are technically unary ops
284// as needed (e.g. add5).
// NOTE(review): extraction dropped the declaration line (a static helper
// taking and returning a Value) — confirm upstream.
  Operation *op = value.getDefiningOp();
  // Walk up through single-operand, side-effect-free defining ops, stopping
  // at the first value produced some other way.
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}
296
// NOTE(review): extraction dropped the first line of this declaration; the
// visible parameters are the candidate contraction body block, a predicate
// classifying (elementwise, reduction) op pairs, and an error stream —
// confirm upstream.
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  // Contraction bodies carry (lhs, rhs, accumulator) block arguments.
  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  // Skip unary side-effect-free wrappers (e.g. casts) around the yielded
  // value before classifying the reduction op.
  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (!reductionOp || reductionOp->getNumResults() != 1 ||
      reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  // One side of the reduction must be the accumulator (block argument #2).
  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  // The other side must be produced by a binary elementwise op.
  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  // The elementwise op must combine block arguments #0 and #1, in either
  // order (modulo unary casts).
  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}
361
/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
// NOTE(review): extraction dropped the declaration line (a bool helper
// taking the add/mul Operation pointers) — confirm upstream.
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Otherwise recurse over the remaining (Add, Mul) template pairs.
  if constexpr (sizeof...(Args) > 0)
  // NOTE(review): the recursive call line was truncated by extraction.
  else
    return false;
}
376
377/// Returns true if the block is a body of a contraction with the kinds of
378/// operations given pairwise by template arguments.
379template <typename... Args>
383
384/// Given an `indexingMap` and its corresponding `iterators`, returns
385/// the positions of the iterators of type `iter` that are indexed by
386/// the `indexingMap` as a permutation. This is useful to infer various
387/// subcomputations on a `LinalgOp`. This is performed by looking up
388/// each result in the `indexingMap` and determining whether:
389/// - It is a single AffineDimExpr.
390/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
// NOTE(review): extraction dropped the parameter lines (an indexing map, its
// iterator-type array, and the iterator kind to collect) — confirm upstream.
                                utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  // Keep a dim only if it appears as a bare AffineDimExpr, carries the
  // requested iterator kind, and occurs in exactly one result of the map.
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
408
namespace {
// Shorthand iterator-kind constants used by the inference helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
413
414/// Infer the iterator types from the init affine map. This looks at which dims
415/// are present in the map results, and returns an iterator types array with
416/// parallel types for dims that are present, and reduction types for dims that
417/// are not present.
static FailureOr<SmallVector<utils::IteratorType>>
// NOTE(review): extraction dropped the declaration line (takes the output
// AffineMap `map`) — confirm upstream.
  if (!map.isProjectedPermutation())
    return failure();
  // Start with all-reduction, then mark every dim appearing in the results
  // as parallel.
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}
428
429/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
430/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
431/// 1. The m dimension is involved in an outer-product along LHS
432/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
433/// 2. The n dimension is involved in an outer-product along RHS
434/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
435/// 3. The k dimension appears as a permutation on LHS and RHS.
436/// 4. m, n and k appear only once in any given indexing.
437/// 5. Optional batch dimensions that appear in all operands are captured.
438/// This allows e.g. detecting that some contraction is embedded within
439/// `linalgOp` with some orthogonal heuristic.
static FailureOr<ContractionDimensions>
// NOTE(review): extraction dropped the parameter lines (the three indexing
// maps `indexingMaps` and their iterator types `iterators`) — confirm
// upstream.
  // Parallel dims indexed (as permutations) by LHS (a), RHS (b), RES (c).
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.m);
  llvm::sort(dimensions.n);
  llvm::sort(dimensions.k);
  return dimensions;
}
482
FailureOr<ContractionDimensions>
// NOTE(review): extraction dropped the declaration line (takes a LinalgOp
// `linalgOp`) — confirm upstream.
  // A contraction has exactly two inputs and one output.
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}
490
FailureOr<ContractionDimensions>
// NOTE(review): extraction dropped the declaration line (takes an
// ArrayRef<AffineMap> `indexingMaps`) — confirm upstream.
  if (indexingMaps.size() != 3)
    return failure();
  // Derive iterator kinds from the output (third) indexing map.
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}
500
501namespace mlir::linalg::detail {
510} // namespace mlir::linalg::detail
511
// NOTE(review): extraction dropped this function's declaration and every
// `return` of its MatchContractionResult enum; the visible logic classifies
// `op` as a contraction and optionally reports inferred dimensions via the
// out-parameter `dimensions` — confirm all control flow against upstream.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
  // NOTE(review): truncated return (not-a-LinalgOp case).
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
  // NOTE(review): truncated return (wrong operand-count case).
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
  // NOTE(review): truncated return (no-reduction case).
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
  // NOTE(review): truncated return (non-projected-permutation case).
  // TODO: more fields than add/mul.
  // clang-format off
  // NOTE(review): the head of the body-check call (a templated contraction
  // body matcher over these op-kind pairs) was truncated by extraction.
      arith::MulFOp, arith::AddFOp,
      arith::MulIOp, arith::AddIOp,
      complex::MulOp, complex::AddOp,
      arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
  }
  // clang-format on

  // Report inferred contraction dimensions if the caller asked for them.
  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
}
545
StringRef
// NOTE(review): extraction dropped the function-name line and every `case`
// label of this switch over the match-result enum — confirm upstream.
  switch (res) {
    return "expected a LinalgOp";
    return "expected op with 2 inputs and 1 output";
    return "expected at least 1 reduction";
    return "expected indexing maps to be projected permutations";
    return "expected add/mul op in the body";
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}
564
// NOTE(review): extraction dropped this function's declaration (takes a
// LinalgOp) and the tail of the final return expression — confirm upstream.
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  // Either the op implements the interface natively, or (truncated) it
  // structurally matches a contraction.
  return isa<ContractionOpInterface>(op) ||
}
573
574/// Verify that a LinalgOp `op` is a contraction.
575/// A Linalg contraction is defined in general terms:
576/// 1. Has 2 input and 1 output shapes.
577/// 2. Has at least one reduction dimension.
578/// 3. Has only projected permutation indexing maps.
579/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
580/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
581/// operations that may change the type (e.g. for mixed-precision).
582/// As a consequence, when vectorization of such an op occurs, the only special
583/// behavior is that the (unique) MulOpType is vectorized into a
584/// `vector.contract`. All other ops are handled in a generic fashion.
585/// In the future, we may wish to allow more input arguments and elementwise and
586/// constant operations that do not involve the reduction dimension(s).
593
594//===----------------------------------------------------------------------===//
595// ConvolutionOpInterface implementation
596//===----------------------------------------------------------------------===//
597
/// Of the given two expressions returns one that is of type T (`lhs` gets
/// preference over `rhs`), or nullptr if neither matches.
template <typename T>
// NOTE(review): extraction dropped the declaration line (takes AffineExpr
// `lhs` and `rhs`, returning T) — confirm upstream.
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}
604
namespace {
/// Walk the indexing expressions for input of a convolution operation to verify
/// its of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExpr as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        auto it = convolvedDimMapping.find(dimPos);
        if (it != convolvedDimMapping.end()) {
          int64_t pairedDim = it->second;
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  // A bare dim access is an unconvolved dimension; duplicates are rejected.
  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  // Bare symbols/constants cannot appear as a whole access expression.
  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    // Record the two dims of the add as duals of each other.
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  // Accepts `d` or `d * (symbol|constant)`, registers `d` as convolved with
  // its coefficient, and returns `d`'s position.
  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      // NOTE(review): the RHS of this assignment (presumably a constant-1
      // affine expression) was truncated by extraction.
      strideAndDilationMapping[dim] =
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      // NOTE(review): the RHS of this assignment (presumably a call to
      // getAffineExprOfType<AffineSymbolExpr>) was truncated by extraction.
      AffineExpr mulExpr =
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace
712
713static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
714 assert(map.isProjectedPermutation() &&
715 "expected map to have projected permutations");
716 llvm::SmallDenseSet<int64_t> preservedDims;
717 for (auto expr : map.getResults())
718 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
719 return preservedDims;
720}
721
// NOTE(review): extraction dropped this helper's declaration (takes
// ArrayRef<AffineExpr> `exprs`) and the declaration of the result vector
// `vals` — confirm upstream.
  // Every expression is required to already be a constant.
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}
732
/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are specified in sorted order, and strides match
/// the order of the filter loop dimensions, while the dilations match the
/// order of the output image dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  // Parallel dims indexed by the filter and by the output, respectively.
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  // NOTE(review): the head of this initializer (presumably a
  // findPermutationsIndexingOperand call on the filter map) was truncated by
  // extraction.
  llvm::SmallDenseSet<int64_t> filterReducedDims =
      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.outputImage);
  llvm::sort(dimensions.outputChannel);
  llvm::sort(dimensions.filterLoop);
  llvm::sort(dimensions.inputChannel);
  llvm::sort(dimensions.depth);

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    // Otherwise recover strides from the input-access coefficients, in
    // output-image dim order.
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    // Dilations come from the filter-loop dim coefficients.
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}
827
828/// Find at least 1 parallel (output_image) and reduction (filter_loop)
829/// dimension candidates that form a convolution subcomputation within
830/// `linalgOp`. The LHS is assumed to be the convolution input while the
831/// RHS is assumed as the filter.
832/// These dimensions are such that:
833/// 1. Optional batch dimensions that appear in the input and filter.
834/// 2. The output_image dimension is involved in a cross-correlation along LHS
835/// (i.e. it is a permutation on RES and LHS and has an associated
836/// filter_loop in RHS).
837/// 3. Optional output_channel dimension is involved in an outer-product along
838/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
839/// LHS).
840/// 4. Optional input_channel dimension appears as a permutation on LHS and
841/// RHS.
842/// 5. The filter_loop dimension appears as a permutation on the RHS and
843/// represents the shape of the kernel cross-correlated along a
844/// corresponding output_image dim.
845/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
846/// 7. All dimensions appear only once in any given indexing map.
847/// This allows e.g. detecting that some convolution is embedded within
848/// `linalgOp` with some orthogonal heuristic.
849/// When multiple dimension occurrences exist that match any classification
850/// indices are returned in sorted order.
851/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
// NOTE(review): extraction dropped the declaration line (takes a LinalgOp
// `linalgOp`) — confirm upstream.
  // Convolutions have exactly two inputs (image, filter) and one output.
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}
868
namespace mlir::linalg::detail {
/// Possible outcomes of matching a linalg op against the convolution
/// interface. Each non-Success value maps to a diagnostic string in
/// getMatchConvolutionMessage.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
882
885 Operation *op, ConvolutionDimensions *dimensions,
886 bool allowEmptyConvolvedDims) {
887 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
888 if (!linalgOp)
890 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
892
893 auto indexingMaps = linalgOp.getIndexingMapsArray();
894
895 // Check the input indexing map has the right form.
896 ConvAccessExprWalker inputExprWalker;
897 if (llvm::any_of(indexingMaps[0].getResults(),
898 [&inputExprWalker](AffineExpr expr) {
899 return failed(inputExprWalker.visit(expr));
900 })) {
902 }
903
904 // Filter and output maps must be projected permutation.
905 if (!indexingMaps[1].isProjectedPermutation() ||
906 !indexingMaps.back().isProjectedPermutation())
908
909 auto iteratorTypes = linalgOp.getIteratorTypesArray();
910
911 llvm::SmallDenseSet<int64_t> outputDims =
912 getPreservedDims(indexingMaps.back());
913 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
914 // Make sure all loops are characterized as one of:
915 // - Batch loop : present in output, as non-convolved in input, not present in
916 // filter.
917 // - Output image dimension : present in output, convolved dims in input, not
918 // present in filter.
919 // - Output channel dimension : present in output, not present in input,
920 // present in filter.
921 // - Filter loop dimension : present in filter, convolved in input, not
922 // present in output.
923 // - Input channel dimension : unconvolved in input, not present in output,
924 // present in filter.
925 // - Depth multiplier : unconvolved in input, present in output, present in
926 // filter.
927 llvm::SmallDenseSet<int64_t> allLoopDims;
928 for (auto outputExpr : indexingMaps.back().getResults()) {
929 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
930 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
931 !filterDims.count(outputDim)) {
932 // Batch dimension.
933 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
935 allLoopDims.insert(outputDim);
936 continue;
937 }
938 if (inputExprWalker.convolvedDims.count(outputDim) &&
939 !filterDims.count(outputDim)) {
940 // Output image Loop dimension.
941 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
943 allLoopDims.insert(outputDim);
944 continue;
945 }
946 if (!inputExprWalker.convolvedDims.count(outputDim) &&
947 !inputExprWalker.unConvolvedDims.count(outputDim) &&
948 filterDims.count(outputDim)) {
949 // Output channel dimension.
950 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
952 allLoopDims.insert(outputDim);
953 continue;
954 }
955 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
956 filterDims.count(outputDim)) {
957 // Depth multiplier.
958 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
960 allLoopDims.insert(outputDim);
961 continue;
962 }
964 }
965 for (auto filterExpr : indexingMaps[1].getResults()) {
966 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
967 if (outputDims.count(filterDim) &&
968 !inputExprWalker.unConvolvedDims.count(filterDim) &&
969 !inputExprWalker.convolvedDims.count(filterDim)) {
970 // Output channel dimension. This is already seen, continue;
971 continue;
972 }
973 if (inputExprWalker.convolvedDims.count(filterDim) &&
974 !outputDims.count(filterDim)) {
975 // Filter loop dimension.
976 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
978 if (allLoopDims.count(filterDim))
980 allLoopDims.insert(filterDim);
981 continue;
982 }
983 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
984 !outputDims.count(filterDim)) {
985 // Input channel dimension.
986 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
988 if (allLoopDims.count(filterDim))
990 allLoopDims.insert(filterDim);
991 continue;
992 }
993 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
994 outputDims.count(filterDim)) {
995 // Depthwise loop. Already seen.
996 continue;
997 }
999 }
1000 // All loops must be covered now.
1001 if (allLoopDims.size() != linalgOp.getNumLoops())
1003
1004 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1006
1007 if (dimensions) {
1008 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1009 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1010 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1011 *dimensions = *res;
1012 }
1013
1015}
1016
1017StringRef
1019 switch (res) {
1021 return "expected a LinalgOp";
1023 return "expected op with 2 inputs and 1 output";
1025 return "unexpected input index map for convolutions";
1027 return "expected output/filter indexing maps to be projected permutations";
1029 return "unexpected loop dimension for convolution op";
1031 return "expected all iterators used to access outputs to be parallel";
1033 return "expected all iterators not used to access outputs to be reduction";
1035 return "expected convolved dim to be non-empty";
1037 return "";
1038 }
1039 llvm_unreachable("unhandled MatchConvolutionResult case");
1040}
1041
1043 bool allowEmptyConvolvedDims) {
1045 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1047}
1048
1055
1056//===----------------------------------------------------------------------===//
1057// FillOpInterface implementation
1058//===----------------------------------------------------------------------===//
1059
1066
1068 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1069 if (!linalgOp)
1071 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1073
1074 OpOperand *value = linalgOp.getDpsInputOperand(0);
1075 if (!linalgOp.isScalar(value))
1077
1079}
1080
1082 auto res = isFillInterfaceImpl(op);
1084 return op->emitError("expected a LinalgOp");
1086 return op->emitError("expected op with 1 input and 1 output");
1088 return op->emitError("expected op with scalar input");
1089
1090 return success();
1091}
1092
1093//===----------------------------------------------------------------------===//
1094// StructuredOpInterface implementation
1095//===----------------------------------------------------------------------===//
1096
1097SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1098 Location loc) {
1100 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1101 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1102 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1103 }
1104 return res;
1105}
1106
1107SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1109 assert(!hasDynamicShape() && "expected operands to have static shapes");
1110 for (OpOperand &opOperand : getOperation()->getOpOperands())
1111 llvm::append_range(res, getShape(&opOperand));
1112 return res;
1113}
1114
1115SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1116 AffineMap map = getLoopsToShapesMap();
1117 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1118 auto viewSizes = createFlatListOfOperandDims(b, loc);
1119 SmallVector<Range, 4> res(numDims);
1120 for (unsigned idx = 0; idx < numRes; ++idx) {
1121 auto result = map.getResult(idx);
1122 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1123 if (res[d.getPosition()].offset)
1124 continue;
1125 res[d.getPosition()] =
1126 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1127 }
1128 }
1129 return res;
1130}
1131
1132/// Visitor to check if any of the given set of positions from AffineDimExprs
1133/// are used within an AffineExpr.
1135 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1136 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1137 : positions(std::move(positions)) {}
1138
1140 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1141 }
1142
1144 return positions.test(dimExpr.getPosition());
1145 }
1146
1147 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1148
1149 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1150
1151private:
1152 llvm::SmallBitVector positions;
1153};
1154
1155static std::pair<int64_t, int64_t>
1157 int64_t inputRankSum = 0;
1158 int64_t outputRankSum = 0;
1159 for (OpOperand *input : op.getDpsInputOperands())
1160 inputRankSum += op.getRank(input);
1161 for (OpOperand &output : op.getDpsInitsMutable())
1162 outputRankSum += op.getRank(&output);
1163 return {inputRankSum, inputRankSum + outputRankSum};
1164}
1165
1166LogicalResult
1167LinalgOp::reifyResultShapes(OpBuilder &b,
1168 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1169 // An example that helps understand the logic below.
1170 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1171 // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1172 // This is achieved as follows.
1173 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1174 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1175 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1176 // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1177 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1178 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1179
1180 // Find the position in the above map that represents the shape of the
1181 // result:dim being inferred.
1182 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1183
1184 /// From loopsToShapesMap extract the submap that represents the shape of the
1185 /// (resultIdx, dim) needed.
1186 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1187 resultShapesSubMapPos.first,
1188 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1189 AffineMap resultShapesFromInputShapesMap =
1190 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1191
1192 // Check that the result dim map does not contain the positions corresponding
1193 // to the outputs.
1194 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1195 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1196 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1197 Location loc = getOperation()->getLoc();
1198 IRRewriter rewriter(b);
1199 SmallVector<OpFoldResult> allResultDimValues =
1201 rewriter, loc, resultShapesFromInputShapesMap,
1202 createFlatListOfOperandDims(b, loc));
1203 int64_t pos = 0;
1204 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1205 for (OpOperand &opOperand : getDpsInitsMutable()) {
1206 SmallVector<OpFoldResult> shapes;
1207 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1208 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1209 if (!shapedType.isDynamicDim(dim)) {
1210 // Static dim: Return IntegerAttr.
1211 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1212 } else {
1213 // Dynamic dim: Return Value.
1214 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1215 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1216 : allResultDimValues[pos];
1217 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1218 }
1219 pos++;
1220 }
1221 reifiedReturnShapes.emplace_back(std::move(shapes));
1222 }
1223 return success();
1224}
1225
1226/// Return the index in the indexingMaps vector that corresponds to this
1227/// `opOperand`.
1228int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1229 auto operandNumber = opOperand->getOperandNumber();
1230 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1231 if (!dpsIface.isDpsInput(opOperand))
1232 return operandNumber;
1233 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1234 assert(!dpsIface.isDpsInit(opOperand));
1235 // Account for potential inputs that are not DPS and may not appear in
1236 // `indexingMaps`.
1237 return cast<DestinationStyleOpInterface>(*this->getOperation())
1238 .getNumDpsInputs() +
1239 operandNumber - start;
1240}
1241
1243 LinalgOp linalgOp = cast<LinalgOp>(op);
1244 // Mixed tensor/buffer operands are not allowed.
1245 if (!linalgOp.hasPureTensorSemantics() &&
1246 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1247 return op->emitOpError("expected to have pure tensor or buffer semantics");
1248
1249 // Before checking indexing maps, we need to make sure the attributes
1250 // referenced by it are valid.
1251 if (linalgOp.hasDynamicIndexingMaps())
1252 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1253 return failure();
1254
1255 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1256 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1257 return failure();
1258
1259 // Set this flag if this op has user defined maps. This is required to guard
1260 // the below error condition which assume default indexing maps.
1261 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1262 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1263 // Domain must be consistent.
1264 unsigned numLoops = linalgOp.getNumLoops();
1265 if (indexingMap.getNumDims() != numLoops)
1266 return op->emitOpError("expected indexing_map #")
1267 << opOperand.getOperandNumber() << " to have " << numLoops
1268 << " dim(s) to match the number of loops";
1269 }
1270 SmallVector<unsigned> redDims;
1271 linalgOp.getReductionDims(redDims);
1272
1273 if (!linalgOp.getShapesToLoopsMap())
1274 return op->emitOpError("expected the shape-to-loops map to be non-null");
1275
1276 // Check the region has exactly one block.
1277 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1278 return op->emitOpError("expects to have 1 region with 1 block");
1279
1280 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1281 // types.
1282 // TODO: once ranked shape types are plugged in, we may want to drop the
1283 // corresponding bbargs, that can never be read from. This will be subject to
1284 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1285 // not used).
1286 Block &block = linalgOp->getRegion(0).front();
1287
1288 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1289 return op->emitOpError("expected as many non-induction variable region "
1290 "arguments as the number of input/output operands");
1291
1292 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1293 Type elementType = opOperand->get().getType();
1294 if (isa<MemRefType, RankedTensorType>(elementType))
1295 elementType = getElementTypeOrSelf(opOperand->get().getType());
1296 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1297 if (elementType != argType)
1298 return op->emitOpError("expected type of bb argument #")
1299 << opOperand->getOperandNumber() << " (" << argType << ")"
1300 << " to match element or self type of the corresponding operand ("
1301 << elementType << ")";
1302 }
1303
1304 return success();
1305}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:757
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:95
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop