// MLIR 23.0.0git — LinalgInterfaces.cpp (doxygen source listing).
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
41 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42 SmallVector<AffineMap> indexingMaps;
43 for (auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
45 continue;
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47 }
48 if (indexingMaps.empty()) {
49 // If there are no indexing maps, the operand can only be dropped
50 // if the op has no loops.
51 return linalgOp.getNumLoops() == 0;
52 }
54 indexingMaps, linalgOp.getContext())) != AffineMap();
55}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
140std::optional<SmallVector<int64_t>>
142 if (auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
143 return SmallVector<int64_t>(broadcastOp.getDimensions().begin(),
144 broadcastOp.getDimensions().end());
145
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
147 if (!op)
148 return std::nullopt;
149
150 // Structural.
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
153 return std::nullopt;
154
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
159 return std::nullopt;
160
161 // Check output is identity map. Broadcast could additionally be
162 // employing permutation of indices and that would be expressible
163 // in linalg.generic but is not expressible for named broadcast op.
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
166 return std::nullopt;
167
168 SmallVector<int64_t> position;
169 auto srcMap = op.getIndexingMapsArray()[0];
170
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
172 return std::nullopt;
173
174 // Check input map is monotonically increasing DimIds.
175 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
177 if (!expr)
178 return std::nullopt;
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
181 return std::nullopt;
182 position.push_back(expr.getPosition());
183 }
184
185 SmallVector<int64_t> broadcastedDims;
186 auto numDims = srcMap.getNumDims();
187 // This is quadratic but number of items is generally small.
188 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
191 }
192 return broadcastedDims;
193}
194
195//===----------------------------------------------------------------------===//
196// TransposeOpInterface implementation
197//===----------------------------------------------------------------------===//
198std::optional<SmallVector<int64_t>>
200 // To specialize as a transpose op, the genericOp must be
201 // all parallel loops, single input, single output, and its body
202 // should be just a yield op, yielding input as output as is (no compute).
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
205 return std::nullopt;
206
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
209 return std::nullopt;
210
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
213
214 // linalg.transpose permutes the dimensions of input using this
215 // rule: dim(result, i) = dim(input, permutation[i])
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
217 return std::nullopt;
218
219 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
220 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
223 }
224 return permutation;
225}
226
227//===----------------------------------------------------------------------===//
228// Elementwise Single Unary/Binary-OpInterface implementation
229//===----------------------------------------------------------------------===//
230static bool
231isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
232 bool allowNonIdentityMaps) {
233 // Check all loops are parallel.
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
235 return false;
236
237 // Check there are arity-inputs, 1-output and all are identity-maps (unless
238 // requested otherwise).
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](AffineMap map) { return map.isIdentity(); })))
243 return false;
244
245 // Init should not be referenced for elementwise operations.
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
247 return false;
248
249 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
250 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
251 // the body, where the first is the elementwise single op and the second a
252 // yield.
253 Block *body = op.getBody();
254 if (body->getOperations().size() != 2)
255 return false;
256
257 // The payload op must have one result and at least arity-many operands
258 // (otherwise not all inputs can be used). It can have additional operands
259 // from outside of the generic op (e.g. div(1, x) for linalg.reciprocal) or
260 // use an input more than once (e.g. mul(x, x) for linalg.square).
261 Operation *oper = &body->front();
262 if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
263 return false;
264
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
268}
269
270bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
271 bool allowNonIdentityMaps) {
272 // All basic elemwise checks.
273 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
274 return false;
275
276 // Check input is actually used.
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
278 return false;
279 return true;
280}
281
282bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
284 op, 2, /*allowNonIdentityMaps=*/false))
285 return false;
286
287 // Check both inputs are used (elementwise).
288 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
289 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
290 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
291 !op.payloadUsesValueFromOperand(inputOpOperand1));
292}
293
294//===----------------------------------------------------------------------===//
295// ContractionOpInterface implementation
296//===----------------------------------------------------------------------===//
297
298/// If the value is defined by a chain of unary side effect-free, go up the
299/// use-def chain until the first value that isn't defined by such an op.
300// TODO: relax to multi-operands with constants, which are technically unary ops
301// as needed (e.g. add5).
303 Operation *op = value.getDefiningOp();
304 while (op && op->getNumOperands() == 1) {
305 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
306 if (!iface || !iface.hasNoEffect())
307 break;
308 value = op->getOperand(0);
309 op = value.getDefiningOp();
310 }
311 return value;
312}
313
315 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
316 llvm::raw_ostream &errs) {
317 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
318 errs << "no terminator in the block";
319 return false;
320 }
321
322 if (block.getNumArguments() != 3) {
323 errs << "expected block with 3 arguments";
324 return false;
325 }
326
327 Operation *terminator = block.getTerminator();
328 if (terminator->getNumOperands() != 1) {
329 errs << "expected terminator with 1 operand";
330 return false;
331 }
332
333 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
334 Operation *reductionOp = yielded.getDefiningOp();
335 if (!reductionOp || reductionOp->getNumResults() != 1 ||
336 reductionOp->getNumOperands() != 2) {
337 errs << "expected reduction op to be binary";
338 return false;
339 }
340
341 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
342 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
343
344 if (reductionLHS != block.getArgument(2) &&
345 reductionRHS != block.getArgument(2)) {
346 errs << "expected reduction to take block argument #2 as one of the "
347 "operands (modulo unary casts)";
348 return false;
349 }
350
351 Value contributed = getSourceSkipUnary(
352 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
353 Operation *elementwiseOp = contributed.getDefiningOp();
354 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
355 elementwiseOp->getNumOperands() != 2) {
356 errs << "expected elementwise op to be binary";
357 return false;
358 }
359
360 if (!isaPair(elementwiseOp, reductionOp)) {
361 errs << "expected reduction/elementwise op kind not satisfied";
362 return false;
363 }
364
365 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
366 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
367 if ((elementwiseLHS == block.getArgument(0) &&
368 elementwiseRHS == block.getArgument(1)) ||
369 (elementwiseLHS == block.getArgument(1) &&
370 elementwiseRHS == block.getArgument(0))) {
371 return true;
372 }
373
374 errs << "expected elementwise op to apply to block arguments (modulo unary "
375 "casts)";
376 return false;
377}
378
379/// Returns true if the two operations are of the kinds specified by a pair of
380/// consecutive template arguments.
381template <typename AddOpTy, typename MulOpTy, typename... Args>
383 static_assert(sizeof...(Args) % 2 == 0,
384 "expected an even number of template arguments");
385 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
386 return true;
387
388 if constexpr (sizeof...(Args) > 0)
390 else
391 return false;
392}
393
394/// Returns true if the block is a body of a contraction with the kinds of
395/// operations given pairwise by template arguments.
396template <typename... Args>
400
401/// Given an `indexingMap` and its corresponding `iterators`, returns
402/// the positions of the iterators of type `iter` that are indexed by
403/// the `indexingMap` as a permutation. This is useful to infer various
404/// subcomputations on a `LinalgOp`. This is performed by looking up
405/// each result in the `indexingMap` and determining whether:
406/// - It is a single AffineDimExpr.
407/// - It is the only result involving this AffineDimExpr.
408static llvm::SmallDenseSet<int64_t>
411 utils::IteratorType iter) {
412 assert(iterators.size() == indexingMap.getNumDims());
413 llvm::SmallDenseSet<int64_t> res;
414 for (AffineExpr e : indexingMap.getResults()) {
415 if (auto d = dyn_cast<AffineDimExpr>(e)) {
416 if (iterators[d.getPosition()] == iter &&
417 llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
418 return e.isFunctionOfDim(d.getPosition());
419 }) == 1)
420 res.insert(d.getPosition());
421 }
422 }
423 return res;
424}
425
namespace {
// Short file-local aliases for the two iterator kinds used throughout the
// inference helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
430
431/// Infer the iterator types from the init affine map. This looks at which dims
432/// are present in the map results, and returns an iterator types array with
433/// parallel types for dims that are present, and reduction types for dims that
434/// are not present.
435static FailureOr<SmallVector<utils::IteratorType>>
437 if (!map.isProjectedPermutation())
438 return failure();
439 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
440 for (auto expr : map.getResults())
441 if (auto dim = dyn_cast<AffineDimExpr>(expr))
442 iterators[dim.getPosition()] = par;
443 return iterators;
444}
445
446/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
447/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
448/// 1. The m dimension is involved in an outer-product along LHS
449/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
450/// 2. The n dimension is involved in an outer-product along RHS
451/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
452/// 3. The k dimension appears as a permutation on LHS and RHS.
453/// 4. m, n and k appear only once in any given indexing.
454/// 5. Optional batch dimensions that appear in all operands are captured.
455/// This allows e.g. detecting that some contraction is embedded within
456/// `linalgOp` with some orthogonal heuristic.
457static FailureOr<ContractionDimensions>
460 llvm::SmallDenseSet<int64_t> a =
461 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
462 llvm::SmallDenseSet<int64_t> b =
463 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
464 llvm::SmallDenseSet<int64_t> c =
465 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
466
467 // A & C - B are the iterators involved in an outer-product along A (the LHS).
468 llvm::SmallDenseSet<int64_t> ac = a;
469 llvm::set_intersect(ac, c);
470 llvm::set_subtract(ac, b);
471 // B & C - A are the iterators involved in an outer-product along B (the RHS).
472 llvm::SmallDenseSet<int64_t> bc = b;
473 llvm::set_intersect(bc, c);
474 llvm::set_subtract(bc, a);
475 // A & B & C are the "batch" dimensions.
476 llvm::SmallDenseSet<int64_t> batches = a;
477 llvm::set_intersect(batches, b);
478 llvm::set_intersect(batches, c);
479
480 // A & B red are the reduction dimensions.
481 llvm::SmallDenseSet<int64_t> ra =
482 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
483 llvm::SmallDenseSet<int64_t> rb =
484 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
485 llvm::set_intersect(ra, rb);
486
487 // Return each set in sorted order.
488 ContractionDimensions dimensions{
489 SmallVector<unsigned, 2>(batches.begin(), batches.end()),
490 SmallVector<unsigned, 2>(ac.begin(), ac.end()),
491 SmallVector<unsigned, 2>(bc.begin(), bc.end()),
492 SmallVector<unsigned, 2>(ra.begin(), ra.end())};
493 llvm::sort(dimensions.batch);
494 llvm::sort(dimensions.m);
495 llvm::sort(dimensions.n);
496 llvm::sort(dimensions.k);
497 return dimensions;
498}
499
500FailureOr<ContractionDimensions>
502 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
503 return failure();
504 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
505 linalgOp.getIteratorTypesArray());
506}
507
508FailureOr<ContractionDimensions>
510 if (indexingMaps.size() != 3)
511 return failure();
512 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
513 if (failed(iterators))
514 return failure();
515 return inferContractionDimsImpl(indexingMaps, iterators.value());
516}
517
namespace mlir::linalg::detail {
// Possible outcomes of matching an op against the contraction structure; used
// by isContractionInterfaceImpl and mapped to messages by
// getMatchContractionMessage.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
528
532 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
533 if (!linalgOp)
535 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
537 auto mapRange = linalgOp.getIndexingMapsArray();
538 if (linalgOp.getNumReductionLoops() == 0)
540 if (llvm::any_of(mapRange,
541 [](AffineMap m) { return !m.isProjectedPermutation(); }))
543 // TODO: more fields than add/mul.
544 // clang-format off
546 arith::MulFOp, arith::AddFOp,
547 arith::MulIOp, arith::AddIOp,
548 complex::MulOp, complex::AddOp,
549 arith::AndIOp, arith::OrIOp>(
550 *linalgOp.getBlock())) {
552 }
553 // clang-format on
554
555 if (dimensions) {
556 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
557 assert(succeeded(res) && "unexpected failure to infer contraction dims");
558 *dimensions = *res;
559 }
561}
562
563StringRef
565 switch (res) {
567 return "expected a LinalgOp";
569 return "expected op with 2 inputs and 1 output";
571 return "expected at least 1 reduction";
573 return "expected indexing maps to be projected permutations";
575 return "expected add/mul op in the body";
577 return "";
578 }
579 llvm_unreachable("unhandled MatchContractionResult case");
580}
581
583 if (!linalgOp)
584 return false;
585 Operation *op = linalgOp.getOperation();
586 return isa<ContractionOpInterface>(op) ||
589}
590
591/// Verify that a LinalgOp `op` is a contraction.
592/// A Linalg contraction is defined in general terms:
593/// 1. Has 2 input and 1 output shapes.
594/// 2. Has at least one reduction dimension.
595/// 3. Has only projected permutation indexing maps.
596/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
597/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
598/// operations that may change the type (e.g. for mixed-precision).
599/// As a consequence, when vectorization of such an op occurs, the only special
600/// behavior is that the (unique) MulOpType is vectorized into a
601/// `vector.contract`. All other ops are handled in a generic fashion.
602/// In the future, we may wish to allow more input arguments and elementwise and
603/// constant operations that do not involve the reduction dimension(s).
610
611//===----------------------------------------------------------------------===//
612// ConvolutionOpInterface implementation
613//===----------------------------------------------------------------------===//
614
615/// Of the given two expressions returns one that is of type T (`lhs` gets
616/// preference over `rhs`)
617template <typename T>
619 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
620}
621
622namespace {
623/// Walk the indexing expressions for input of a convolution operation to verify
624/// its of the right form, either
625/// - AffineDimExpr
626/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
627/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
628///
629/// classifies the AffineDimExpr as convolved dimensions or unconvolved
630/// dimensions and verifies each dimension occurs only once.
631struct ConvAccessExprWalker
632 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
633 // Stores dimensions used in expressions of the above form.
634 llvm::SmallDenseSet<int64_t> convolvedDims;
635 // Stores the dual mapping between LHS and RHS of convolution exprs.
636 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
637 // Stores single use dimensions used by an AffineDimExpr.
638 llvm::SmallDenseSet<int64_t> unConvolvedDims;
639 // Stores a mapping from convolved dims to their coefficient.
640 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
641
642 // Removes dims with multiple uses in the source input map from dimension
643 // sets tracked by this walker.
644 void clearMultiUseDims(AffineMap map) {
645 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
646 if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
647 return e.isFunctionOfDim(dimPos);
648 }) > 1) {
649 convolvedDims.erase(dimPos);
650 unConvolvedDims.erase(dimPos);
651 // If a duplicate dim is marked as convolved, the pair of the duplicate
652 // dim must be removed from the map as well.
653 auto it = convolvedDimMapping.find(dimPos);
654 if (it != convolvedDimMapping.end()) {
655 int64_t pairedDim = it->second;
656 convolvedDims.erase(pairedDim);
657 unConvolvedDims.erase(pairedDim);
658 strideAndDilationMapping.erase(pairedDim);
659 convolvedDimMapping.erase(dimPos);
660 convolvedDimMapping.erase(pairedDim);
661 }
662 }
663 }
664 }
665
666 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
667 unsigned position = dimExpr.getPosition();
668 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
669 return failure();
670 }
671 unConvolvedDims.insert(position);
672 return success();
673 }
674
675 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
676
677 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
678
679 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
680 // In pre-order visit, top level op has to be an add op.
681 if (binaryExpr.getKind() != AffineExprKind::Add)
682 return failure();
683 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
684 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
685 if (failed(lhsDimPos) || failed(rhsDimPos))
686 return failure();
687 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
688 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
689 return success();
690 }
691
692 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
693 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
694 int64_t dim = dimExpr.getPosition();
695 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
696 return failure();
697 // Stride/dilation for this dim is implicitly 1.
698 strideAndDilationMapping[dim] =
700 convolvedDims.insert(dim);
701 return dim;
702 }
703 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
704 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
705 return failure();
706 auto lhsExpr = symbolMulExpr.getLHS();
707 auto rhsExpr = symbolMulExpr.getRHS();
708 // Check for symbol expression.
709 AffineExpr mulExpr =
711 // If there was no symbol expr, check for constant expression.
712 if (!mulExpr) {
713 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
714 }
715 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
716 if (!mulExpr || !dimExpr)
717 return failure();
718 int64_t dim = dimExpr.getPosition();
719 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
720 return failure();
721 strideAndDilationMapping[dim] = mulExpr;
722 convolvedDims.insert(dim);
723 return dim;
724 }
725 return failure();
726 }
727};
728} // namespace
729
730static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
731 assert(map.isProjectedPermutation() &&
732 "expected map to have projected permutations");
733 llvm::SmallDenseSet<int64_t> preservedDims;
734 for (auto expr : map.getResults())
735 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
736 return preservedDims;
737}
738
742 for (auto e : exprs) {
743 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
744 assert(constantExpr && "Found non-constant stride/dilation");
745 vals.push_back(constantExpr.getValue());
746 }
747 return vals;
748}
749
750/// Classifies dimensions in the `linalgOp` used by a convolution
751/// subcomputation, as captured by `inputExprWalker`. If
752/// `allowEmptyConvolvedDims` is not set this will fail if there is not
753/// at least one convolved dimension pair (output image + filter loop).
754///
755/// The returned dimensions are ordered as follows:
756/// - `outputImage` is sorted by dimension index.
757/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
758/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
759/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
760/// - `strides[i]` corresponds to `outputImage[i]`.
761/// - `dilations[i]` corresponds to `filterLoop[i]`.
762/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
763static FailureOr<ConvolutionDimensions>
764inferConvolutionDimsImpl(LinalgOp linalgOp,
765 ConvAccessExprWalker &inputExprWalker,
766 bool allowEmptyConvolvedDims) {
767 auto filterMap =
768 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
769 auto outputMap =
770 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
771 llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
772 filterMap, linalgOp.getIteratorTypesArray(), par);
773 llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
774 outputMap, linalgOp.getIteratorTypesArray(), par);
775
776 // unConvolvedDims & outputDims - filterDims are the batch iterators.
777 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
778 llvm::set_intersect(batch, outputDims);
779 llvm::set_subtract(batch, filterDims);
780
781 // convolvedDims & outputDims are the output image iterators.
782 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
783 llvm::set_intersect(oi, outputDims);
784
785 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
786 llvm::SmallDenseSet<int64_t> oc = filterDims;
787 llvm::set_intersect(oc, outputDims);
788 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
789
790 // filterDims & outputDims & unConvolvedDims are the depth iterators.
791 llvm::SmallDenseSet<int64_t> depth = filterDims;
792 llvm::set_intersect(depth, outputDims);
793 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
794
795 llvm::SmallDenseSet<int64_t> filterReducedDims =
797 linalgOp.getIteratorTypesArray(), red);
798
799 // convolvedDims & filterReducedDims are the filter loop iterators.
800 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
801 llvm::set_intersect(fl, filterReducedDims);
802
803 // unConvolvedDims & filterReducedDims are the input channel iterators.
804 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
805 llvm::set_intersect(ic, filterReducedDims);
806
807 if (oi.empty() && !allowEmptyConvolvedDims)
808 return failure();
809
810 // Return each set in sorted order, with outputImage and filterLoop
811 // ordered so that outputImage[i] pairs with filterLoop[i].
812 ConvolutionDimensions dimensions{
813 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
814 SmallVector<unsigned, 2>(oi.begin(), oi.end()),
815 SmallVector<unsigned, 2>(oc.begin(), oc.end()),
816 /*filterLoop=*/SmallVector<unsigned, 2>{},
817 SmallVector<unsigned, 2>(ic.begin(), ic.end()),
818 SmallVector<unsigned, 2>(depth.begin(), depth.end()),
819 /*strides=*/SmallVector<int64_t, 2>{},
820 /*dilations=*/SmallVector<int64_t, 2>{}};
821 llvm::sort(dimensions.batch);
822 llvm::sort(dimensions.outputImage);
823 llvm::sort(dimensions.outputChannel);
824 llvm::sort(dimensions.inputChannel);
825 llvm::sort(dimensions.depth);
826 // Order filterLoop to match the pairing with outputImage. Each outputImage
827 // dimension has a corresponding filterLoop dimension from the convolution
828 // access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
829 // filterLoop[i].
830 for (unsigned oiDim : dimensions.outputImage)
831 dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
832
833 // Use the op carried strides/dilations attribute if present.
834 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
835 if (!nativeStrides) {
836 SmallVector<AffineExpr, 2> strideExprs;
837 for (unsigned oiDim : dimensions.outputImage)
838 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
839 dimensions.strides = getConstantsFromExprList(strideExprs);
840 } else {
841 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
842 }
843 auto nativeDilations =
844 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
845 if (!nativeDilations) {
846 SmallVector<AffineExpr, 2> dilationExprs;
847 for (unsigned flDim : dimensions.filterLoop)
848 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
849 dimensions.dilations = getConstantsFromExprList(dilationExprs);
850 } else {
851 dimensions.dilations =
852 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
853 }
854 return dimensions;
855}
856
857/// Find at least 1 parallel (output_image) and reduction (filter_loop)
858/// dimension candidates that form a convolution subcomputation within
859/// `linalgOp`. The LHS is assumed to be the convolution input while the
860/// RHS is assumed as the filter.
861/// These dimensions are such that:
862/// 1. Optional batch dimensions that appear in the input and filter.
863/// 2. The output_image dimension is involved in a cross-correlation along LHS
864/// (i.e. it is a permutation on RES and LHS and has an associated
865/// filter_loop in RHS).
866/// 3. Optional output_channel dimension is involved in an outer-product along
867/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
868/// LHS).
869/// 4. Optional input_channel dimension appears as a permutation on LHS and
870/// RHS.
871/// 5. The filter_loop dimension appears as a permutation on the RHS and
872/// represents the shape of the kernel cross-correlated along a
873/// corresponding output_image dim.
874/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
875/// 7. All dimensions appear only once in any given indexing map.
876/// This allows e.g. detecting that some convolution is embedded within
877/// `linalgOp` with some orthogonal heuristic.
878///
879/// The `outputImage` and `filterLoop` arrays are ordered such that
880/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
881/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
882/// dimension 2). Other dimension sets are returned in sorted order.
883///
884/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
885FailureOr<ConvolutionDimensions>
887 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
888 return failure();
889
890 auto indexingMaps = linalgOp.getIndexingMapsArray();
891
892 // Check the input indexing map has the right form.
893 ConvAccessExprWalker inputExprWalker;
894 for (AffineExpr expr : indexingMaps[0].getResults())
895 (void)inputExprWalker.visit(expr);
896 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
897
898 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
899 /*allowEmptyConvolvedDims=*/false);
900}
901
namespace mlir::linalg::detail {
/// Result of matching an op against `ConvolutionOpInterface`. `Success` means
/// a full structural match; every other value names the first requirement
/// that was violated (see `getMatchConvolutionMessage` for user-facing text).
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
915
918 Operation *op, ConvolutionDimensions *dimensions,
919 bool allowEmptyConvolvedDims) {
920 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
921 if (!linalgOp)
923 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
925
926 auto indexingMaps = linalgOp.getIndexingMapsArray();
927
928 // Check the input indexing map has the right form.
929 ConvAccessExprWalker inputExprWalker;
930 if (llvm::any_of(indexingMaps[0].getResults(),
931 [&inputExprWalker](AffineExpr expr) {
932 return failed(inputExprWalker.visit(expr));
933 })) {
935 }
936
937 // Filter and output maps must be projected permutation.
938 if (!indexingMaps[1].isProjectedPermutation() ||
939 !indexingMaps.back().isProjectedPermutation())
941
942 auto iteratorTypes = linalgOp.getIteratorTypesArray();
943
944 llvm::SmallDenseSet<int64_t> outputDims =
945 getPreservedDims(indexingMaps.back());
946 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
947 // Make sure all loops are characterized as one of:
948 // - Batch loop : present in output, as non-convolved in input, not present in
949 // filter.
950 // - Output image dimension : present in output, convolved dims in input, not
951 // present in filter.
952 // - Output channel dimension : present in output, not present in input,
953 // present in filter.
954 // - Filter loop dimension : present in filter, convolved in input, not
955 // present in output.
956 // - Input channel dimension : unconvolved in input, not present in output,
957 // present in filter.
958 // - Depth multiplier : unconvolved in input, present in output, present in
959 // filter.
960 llvm::SmallDenseSet<int64_t> allLoopDims;
961 for (auto outputExpr : indexingMaps.back().getResults()) {
962 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
963 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
964 !filterDims.count(outputDim)) {
965 // Batch dimension.
966 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
968 allLoopDims.insert(outputDim);
969 continue;
970 }
971 if (inputExprWalker.convolvedDims.count(outputDim) &&
972 !filterDims.count(outputDim)) {
973 // Output image Loop dimension.
974 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
976 allLoopDims.insert(outputDim);
977 continue;
978 }
979 if (!inputExprWalker.convolvedDims.count(outputDim) &&
980 !inputExprWalker.unConvolvedDims.count(outputDim) &&
981 filterDims.count(outputDim)) {
982 // Output channel dimension.
983 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
985 allLoopDims.insert(outputDim);
986 continue;
987 }
988 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
989 filterDims.count(outputDim)) {
990 // Depth multiplier.
991 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
993 allLoopDims.insert(outputDim);
994 continue;
995 }
997 }
998 for (auto filterExpr : indexingMaps[1].getResults()) {
999 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1000 if (outputDims.count(filterDim) &&
1001 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1002 !inputExprWalker.convolvedDims.count(filterDim)) {
1003 // Output channel dimension. This is already seen, continue;
1004 continue;
1005 }
1006 if (inputExprWalker.convolvedDims.count(filterDim) &&
1007 !outputDims.count(filterDim)) {
1008 // Filter loop dimension.
1009 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1011 if (allLoopDims.count(filterDim))
1013 allLoopDims.insert(filterDim);
1014 continue;
1015 }
1016 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1017 !outputDims.count(filterDim)) {
1018 // Input channel dimension.
1019 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1021 if (allLoopDims.count(filterDim))
1023 allLoopDims.insert(filterDim);
1024 continue;
1025 }
1026 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1027 outputDims.count(filterDim)) {
1028 // Depthwise loop. Already seen.
1029 continue;
1030 }
1032 }
1033 // All loops must be covered now.
1034 if (allLoopDims.size() != linalgOp.getNumLoops())
1036
1037 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1039
1040 if (dimensions) {
1041 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1042 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1043 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1044 *dimensions = *res;
1045 }
1046
1048}
1049
1050StringRef
1052 switch (res) {
1054 return "expected a LinalgOp";
1056 return "expected op with 2 inputs and 1 output";
1058 return "unexpected input index map for convolutions";
1060 return "expected output/filter indexing maps to be projected permutations";
1062 return "unexpected loop dimension for convolution op";
1064 return "expected all iterators used to access outputs to be parallel";
1066 return "expected all iterators not used to access outputs to be reduction";
1068 return "expected convolved dim to be non-empty";
1070 return "";
1071 }
1072 llvm_unreachable("unhandled MatchConvolutionResult case");
1073}
1074
1076 bool allowEmptyConvolvedDims) {
1078 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1080}
1081
1088
1089//===----------------------------------------------------------------------===//
1090// FillOpInterface implementation
1091//===----------------------------------------------------------------------===//
1092
namespace {
/// Result of matching an op against the FillOpInterface structure. Anything
/// other than `Success` names the first violated requirement.
enum class MatchFillResult {
  Success = 0,      // Op is a valid fill.
  NotLinalgOp,      // Op does not implement linalg::LinalgOp.
  WrongNumOperands, // A fill needs exactly 1 input and 1 init.
  NotScalarInput,   // The single input must be a scalar.
  TypeMismatch      // Scalar type must equal the init element type.
};
} // namespace
1102
1103static MatchFillResult isFillInterfaceImpl(Operation *op) {
1104 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1105 if (!linalgOp)
1106 return MatchFillResult::NotLinalgOp;
1107 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1108 return MatchFillResult::WrongNumOperands;
1109
1110 OpOperand *value = linalgOp.getDpsInputOperand(0);
1111 if (!linalgOp.isScalar(value))
1112 return MatchFillResult::NotScalarInput;
1113
1114 // Check that the scalar input type matches the output element type.
1115 OpOperand *output = linalgOp.getDpsInitOperand(0);
1116 Type scalarType = value->get().getType();
1117 Type outputElementType = getElementTypeOrSelf(output->get().getType());
1118 if (scalarType != outputElementType)
1119 return MatchFillResult::TypeMismatch;
1120
1121 return MatchFillResult::Success;
1122}
1123
1125 MatchFillResult res = isFillInterfaceImpl(op);
1126 if (res == MatchFillResult::NotLinalgOp)
1127 return op->emitError("expected a LinalgOp");
1128 if (res == MatchFillResult::WrongNumOperands)
1129 return op->emitError("expected op with 1 input and 1 output");
1130 if (res == MatchFillResult::NotScalarInput)
1131 return op->emitError("expected op with scalar input");
1132 if (res == MatchFillResult::TypeMismatch) {
1133 auto linalgOp = cast<linalg::LinalgOp>(op);
1134 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1135 Type outputElementType =
1136 getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
1137 return op->emitOpError("expected fill value type (")
1138 << scalarType << ") to match output element type ("
1139 << outputElementType << ")";
1140 }
1141
1142 return success();
1143}
1144
1145//===----------------------------------------------------------------------===//
1146// StructuredOpInterface implementation
1147//===----------------------------------------------------------------------===//
1148
1149SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1150 Location loc) {
1152 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1153 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1154 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1155 }
1156 return res;
1157}
1158
1159SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1161 assert(!hasDynamicShape() && "expected operands to have static shapes");
1162 for (OpOperand &opOperand : getOperation()->getOpOperands())
1163 llvm::append_range(res, getShape(&opOperand));
1164 return res;
1165}
1166
1167SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1168 AffineMap map = getLoopsToShapesMap();
1169 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1170 auto viewSizes = createFlatListOfOperandDims(b, loc);
1171 SmallVector<Range, 4> res(numDims);
1172 for (unsigned idx = 0; idx < numRes; ++idx) {
1173 auto result = map.getResult(idx);
1174 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1175 if (res[d.getPosition()].offset)
1176 continue;
1177 res[d.getPosition()] =
1178 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1179 }
1180 }
1181 return res;
1182}
1183
1184/// Visitor to check if any of the given set of positions from AffineDimExprs
1185/// are used within an AffineExpr.
1187 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1188 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1189 : positions(std::move(positions)) {}
1190
1192 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1193 }
1194
1196 return positions.test(dimExpr.getPosition());
1197 }
1198
1199 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1200
1201 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1202
1203private:
1204 llvm::SmallBitVector positions;
1205};
1206
1207static std::pair<int64_t, int64_t>
1209 int64_t inputRankSum = 0;
1210 int64_t outputRankSum = 0;
1211 for (OpOperand *input : op.getDpsInputOperands())
1212 inputRankSum += op.getRank(input);
1213 for (OpOperand &output : op.getDpsInitsMutable())
1214 outputRankSum += op.getRank(&output);
1215 return {inputRankSum, inputRankSum + outputRankSum};
1216}
1217
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // Express each result dimension in terms of the input operand dimensions by
  // composing the loops-to-shapes map with its (pseudo-)inverse.
  //
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs. A result expression that still references an output
  // position cannot be computed from input shapes alone.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  // Fold the composed map applied to all operand dims in one shot.
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the expression depends on an output
        // position, fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1277
1278/// Return the index in the indexingMaps vector that corresponds to this
1279/// `opOperand`.
1280int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1281 auto operandNumber = opOperand->getOperandNumber();
1282 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1283 if (!dpsIface.isDpsInput(opOperand))
1284 return operandNumber;
1285 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1286 assert(!dpsIface.isDpsInit(opOperand));
1287 // Account for potential inputs that are not DPS and may not appear in
1288 // `indexingMaps`.
1289 return cast<DestinationStyleOpInterface>(*this->getOperation())
1290 .getNumDpsInputs() +
1291 operandNumber - start;
1292}
1293
1295 LinalgOp linalgOp = cast<LinalgOp>(op);
1296 // Mixed tensor/buffer operands are not allowed.
1297 if (!linalgOp.hasPureTensorSemantics() &&
1298 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1299 return op->emitOpError("expected to have pure tensor or buffer semantics");
1300
1301 // Before checking indexing maps, we need to make sure the attributes
1302 // referenced by it are valid.
1303 if (linalgOp.hasDynamicIndexingMaps())
1304 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1305 return failure();
1306
1307 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1308 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1309 return failure();
1310
1311 // Set this flag if this op has user defined maps. This is required to guard
1312 // the below error condition which assume default indexing maps.
1313 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1314 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1315 // Domain must be consistent.
1316 unsigned numLoops = linalgOp.getNumLoops();
1317 if (indexingMap.getNumDims() != numLoops)
1318 return op->emitOpError("expected indexing_map #")
1319 << opOperand.getOperandNumber() << " to have " << numLoops
1320 << " dim(s) to match the number of loops";
1321 }
1322 SmallVector<unsigned> redDims;
1323 linalgOp.getReductionDims(redDims);
1324
1325 if (!linalgOp.getShapesToLoopsMap())
1326 return op->emitOpError("expected the shape-to-loops map to be non-null");
1327
1328 // Check the region has exactly one block.
1329 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1330 return op->emitOpError("expects to have 1 region with 1 block");
1331
1332 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1333 // types.
1334 // TODO: once ranked shape types are plugged in, we may want to drop the
1335 // corresponding bbargs, that can never be read from. This will be subject to
1336 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1337 // not used).
1338 Block &block = linalgOp->getRegion(0).front();
1339
1340 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1341 return op->emitOpError("expected as many non-induction variable region "
1342 "arguments as the number of input/output operands");
1343
1344 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1345 Type elementType = opOperand->get().getType();
1346 if (isa<MemRefType, RankedTensorType>(elementType))
1347 elementType = getElementTypeOrSelf(opOperand->get().getType());
1348 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1349 if (elementType != argType)
1350 return op->emitOpError("expected type of bb argument #")
1351 << opOperand->getOperandNumber() << " (" << argType << ")"
1352 << " to match element or self type of the corresponding operand ("
1353 << elementType << ")";
1354 }
1355
1356 return success();
1357}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:783
unsigned getNumOperands()
Definition Operation.h:372
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop