MLIR 23.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/MLIRContext.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the Linalg interfaces.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
41 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42 SmallVector<AffineMap> indexingMaps;
43 for (auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
45 continue;
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47 }
48 if (indexingMaps.empty()) {
49 // If there are no indexing maps, the operand can only be dropped
50 // if the op has no loops.
51 return linalgOp.getNumLoops() == 0;
52 }
54 indexingMaps, linalgOp.getContext())) != AffineMap();
55}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
/// Matches ops that behave like a named broadcast and returns the loop
/// dimensions that are added by the broadcast. For a `BroadcastOp` the op's
/// own dimensions are returned; for a `GenericOp` they are derived from the
/// indexing maps. Returns std::nullopt when the op does not match.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (numbering jumps from 140 to 142); the body reads
// a parameter named `linalgOp`.
140std::optional<SmallVector<int64_t>>
// A named broadcast op already carries its dimensions.
142 if (auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
143 return SmallVector<int64_t>(broadcastOp.getDimensions().begin(),
144 broadcastOp.getDimensions().end());
145
// Anything else must be a linalg.generic to be considered.
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
147 if (!op)
148 return std::nullopt;
149
150 // Structural.
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
153 return std::nullopt;
154
// Both the input and the init must be a memref or a ranked tensor.
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
159 return std::nullopt;
160
161 // Check output is identity map. Broadcast could additionally be
162 // employing permutation of indices and that would be expressible
163 // in linalg.generic but is not expressible for named broadcast op.
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
166 return std::nullopt;
167
// `position` collects the loop dims referenced by the input map, in order.
168 SmallVector<int64_t> position;
169 auto srcMap = op.getIndexingMapsArray()[0];
170
// The input map must have strictly fewer results than the output map.
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
172 return std::nullopt;
173
174 // Check input map is monotonically increasing DimIds.
175 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
// Any result that is not a plain dim expression disqualifies the op.
177 if (!expr)
178 return std::nullopt;
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
181 return std::nullopt;
182 position.push_back(expr.getPosition());
183 }
184
// Every loop dim that the input map does not reference is a broadcast dim.
185 SmallVector<int64_t> broadcastedDims;
186 auto numDims = srcMap.getNumDims();
187 // This is quadratic but number of items is generally small.
188 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
191 }
192 return broadcastedDims;
193}
194
195//===----------------------------------------------------------------------===//
196// TransposeOpInterface implementation
197//===----------------------------------------------------------------------===//
/// Matches an op that behaves like a transpose and returns its permutation
/// vector, or std::nullopt when the op does not match.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (numbering jumps from 198 to 200); the body reads
// a parameter named `op`.
198std::optional<SmallVector<int64_t>>
200 // To specialize as a transpose op, the genericOp must be
201 // all parallel loops, single input, single output, and its body
202 // should be just a yield op, yielding input as output as is (no compute).
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
205 return std::nullopt;
206
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
209 return std::nullopt;
210
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
213
214 // linalg.transpose permutes the dimensions of input using this
215 // rule: dim(result, i) = dim(input, permutation[i])
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
217 return std::nullopt;
218
// Input result `i` reads loop dim `p`, so the entry at index `p` is set to
// `i`; the cast is safe because the input map was checked to be a permutation.
219 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
220 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
223 }
224 return permutation;
225}
226
227//===----------------------------------------------------------------------===//
228// Elementwise Single Unary/Binary-OpInterface implementation
229//===----------------------------------------------------------------------===//
230static bool
231isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
232 bool allowNonIdentityMaps) {
233 // Check all loops are parallel.
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
235 return false;
236
237 // Check there are arity-inputs, 1-output and all are identity-maps (unless
238 // requested otherwise).
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](AffineMap map) { return map.isIdentity(); })))
243 return false;
244
245 // Init should not be referenced for elementwise operations.
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
247 return false;
248
249 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
250 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
251 // the body, where the first is the elementwise single op and the second a
252 // yield.
253 Block *body = op.getBody();
254 if (body->getOperations().size() != 2)
255 return false;
256
257 // The payload op must have one result and at least arity-many operands
258 // (otherwise not all inputs can be used). It can have additional operands
259 // from outside of the generic op (e.g. div(1, x) for linalg.reciprocal) or
260 // use an input more than once (e.g. mul(x, x) for linalg.square).
261 Operation *oper = &body->front();
262 if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
263 return false;
264
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
268}
269
270bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
271 bool allowNonIdentityMaps) {
272 // All basic elemwise checks.
273 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
274 return false;
275
276 // Check input is actually used.
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
278 return false;
279 return true;
280}
281
282bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
283 bool allowNonIdentityMaps) {
284 // All basic elemwise checks.
285 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
286 return false;
287
288 // Check both inputs are used (elementwise).
289 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
290 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
291 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
292 !op.payloadUsesValueFromOperand(inputOpOperand1));
293}
294
295//===----------------------------------------------------------------------===//
296// ContractionOpInterface implementation
297//===----------------------------------------------------------------------===//
298
299/// If the value is defined by a chain of unary side effect-free ops, go up the
300/// use-def chain until the first value that isn't defined by such an op.
301// TODO: relax to multi-operands with constants, which are technically unary ops
302// as needed (e.g. add5).
// NOTE(review): the signature line is missing from this listing (numbering
// jumps from 302 to 304); the body reads and returns a `Value` named `value`.
304 Operation *op = value.getDefiningOp();
// Walk upwards while the defining op exists and has exactly one operand.
305 while (op && op->getNumOperands() == 1) {
// Stop at ops that do not implement MemoryEffectOpInterface or that report
// having effects.
306 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
307 if (!iface || !iface.hasNoEffect())
308 break;
309 value = op->getOperand(0);
310 op = value.getDefiningOp();
311 }
312 return value;
313}
314
// Checks that `block` has the shape of a contraction body: a terminator with
// one operand that (modulo unary side effect-free ops) is a binary reduction
// op combining block argument #2 with a binary elementwise op applied to
// block arguments #0 and #1. `isaPair` decides whether the (elementwise,
// reduction) op pair is of an accepted kind. On a mismatch a diagnostic is
// written to `errs` and false is returned.
// NOTE(review): the line carrying the function name and return type is
// missing from this listing (numbering jumps from 314 to 316).
316 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
317 llvm::raw_ostream &errs) {
318 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
319 errs << "no terminator in the block";
320 return false;
321 }
322
323 if (block.getNumArguments() != 3) {
324 errs << "expected block with 3 arguments";
325 return false;
326 }
327
328 Operation *terminator = block.getTerminator();
329 if (terminator->getNumOperands() != 1) {
330 errs << "expected terminator with 1 operand";
331 return false;
332 }
333
// Look through unary ops between the terminator and the reduction op.
334 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
335 Operation *reductionOp = yielded.getDefiningOp();
336 if (!reductionOp || reductionOp->getNumResults() != 1 ||
337 reductionOp->getNumOperands() != 2) {
338 errs << "expected reduction op to be binary";
339 return false;
340 }
341
342 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
343 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
344
// One side of the reduction must be block argument #2.
345 if (reductionLHS != block.getArgument(2) &&
346 reductionRHS != block.getArgument(2)) {
347 errs << "expected reduction to take block argument #2 as one of the "
348 "operands (modulo unary casts)";
349 return false;
350 }
351
// The other side of the reduction is the value contributed by the
// elementwise op; if the LHS is a block argument, the RHS is taken.
352 Value contributed = getSourceSkipUnary(
353 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
354 Operation *elementwiseOp = contributed.getDefiningOp();
355 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
356 elementwiseOp->getNumOperands() != 2) {
357 errs << "expected elementwise op to be binary";
358 return false;
359 }
360
361 if (!isaPair(elementwiseOp, reductionOp)) {
362 errs << "expected reduction/elementwise op kind not satisfied";
363 return false;
364 }
365
// The elementwise op must consume block arguments #0 and #1, in either order.
366 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
367 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
368 if ((elementwiseLHS == block.getArgument(0) &&
369 elementwiseRHS == block.getArgument(1)) ||
370 (elementwiseLHS == block.getArgument(1) &&
371 elementwiseRHS == block.getArgument(0))) {
372 return true;
373 }
374
375 errs << "expected elementwise op to apply to block arguments (modulo unary "
376 "casts)";
377 return false;
378}
379
380/// Returns true if the two operations are of the kinds specified by a pair of
381/// consecutive template arguments.
382template <typename AddOpTy, typename MulOpTy, typename... Args>
384 static_assert(sizeof...(Args) % 2 == 0,
385 "expected an even number of template arguments");
386 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
387 return true;
388
389 if constexpr (sizeof...(Args) > 0)
391 else
392 return false;
393}
394
395/// Returns true if the block is a body of a contraction with the kinds of
396/// operations given pairwise by template arguments.
397template <typename... Args>
401
402/// Given an `indexingMap` and its corresponding `iterators`, returns
403/// the positions of the iterators of type `iter` that are indexed by
404/// the `indexingMap` as a permutation. This is useful to infer various
405/// subcomputations on a `LinalgOp`. This is performed by looking up
406/// each result in the `indexingMap` and determining whether:
407/// - It is a single AffineDimExpr.
408/// - It is the only result involving this AffineDimExpr.
// NOTE(review): the lines carrying the function name and the first
// parameters are missing from this listing (numbering jumps from 409 to 412).
409static llvm::SmallDenseSet<int64_t>
412 utils::IteratorType iter) {
// One iterator type is expected per dimension of the map.
413 assert(iterators.size() == indexingMap.getNumDims());
414 llvm::SmallDenseSet<int64_t> res;
415 for (AffineExpr e : indexingMap.getResults()) {
416 if (auto d = dyn_cast<AffineDimExpr>(e)) {
// Keep the dim only if its iterator has the requested type and exactly one
// result of the map is a function of it.
417 if (iterators[d.getPosition()] == iter &&
418 llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
419 return e.isFunctionOfDim(d.getPosition());
420 }) == 1)
421 res.insert(d.getPosition());
422 }
423 }
424 return res;
425}
426
427namespace {
428auto par = utils::IteratorType::parallel;
429auto red = utils::IteratorType::reduction;
430} // namespace
431
432/// Infer the iterator types from the init affine map. This looks at which dims
433/// are present in the map results, and returns an iterator types array with
434/// parallel types for dims that are present, and reduction types for dims that
435/// are not present.
/// Returns failure if the map is not a projected permutation.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (numbering jumps from 436 to 438); the body reads
// a parameter named `map`.
436static FailureOr<SmallVector<utils::IteratorType>>
438 if (!map.isProjectedPermutation())
439 return failure();
// Start with every dim marked as reduction, then flip the dims that appear in
// the results to parallel.
440 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
441 for (auto expr : map.getResults())
442 if (auto dim = dyn_cast<AffineDimExpr>(expr))
443 iterators[dim.getPosition()] = par;
444 return iterators;
445}
446
447/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
448/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
449/// 1. The m dimension is involved in an outer-product along LHS
450/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
451/// 2. The n dimension is involved in an outer-product along RHS
452/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
453/// 3. The k dimension appears as a permutation on LHS and RHS.
454/// 4. m, n and k appear only once in any given indexing.
455/// 5. Optional batch dimensions that appear in all operands are captured.
456/// This allows e.g. detecting that some contraction is embedded within
457/// `linalgOp` with some orthogonal heuristic.
// NOTE(review): the lines carrying the function name and parameter list are
// missing from this listing (numbering jumps from 458 to 461); the body reads
// `indexingMaps` (indices 0, 1 and 2) and `iterators`.
458static FailureOr<ContractionDimensions>
// a, b and c hold the parallel dims that index operand 0, 1 and 2
// respectively as a permutation.
461 llvm::SmallDenseSet<int64_t> a =
462 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
463 llvm::SmallDenseSet<int64_t> b =
464 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
465 llvm::SmallDenseSet<int64_t> c =
466 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
467
468 // A & C - B are the iterators involved in an outer-product along A (the LHS).
469 llvm::SmallDenseSet<int64_t> ac = a;
470 llvm::set_intersect(ac, c);
471 llvm::set_subtract(ac, b);
472 // B & C - A are the iterators involved in an outer-product along B (the RHS).
473 llvm::SmallDenseSet<int64_t> bc = b;
474 llvm::set_intersect(bc, c);
475 llvm::set_subtract(bc, a);
476 // A & B & C are the "batch" dimensions.
477 llvm::SmallDenseSet<int64_t> batches = a;
478 llvm::set_intersect(batches, b);
479 llvm::set_intersect(batches, c);
480
481 // A & B red are the reduction dimensions.
482 llvm::SmallDenseSet<int64_t> ra =
483 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
484 llvm::SmallDenseSet<int64_t> rb =
485 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
486 llvm::set_intersect(ra, rb);
487
488 // Return each set in sorted order.
// NOTE(review): the initializers are presumably in the field order
// batch, m, n, k - confirm against the ContractionDimensions declaration.
489 ContractionDimensions dimensions{
490 SmallVector<unsigned, 2>(batches.begin(), batches.end()),
491 SmallVector<unsigned, 2>(ac.begin(), ac.end()),
492 SmallVector<unsigned, 2>(bc.begin(), bc.end()),
493 SmallVector<unsigned, 2>(ra.begin(), ra.end())};
// The sets are unordered, so each vector is sorted explicitly.
494 llvm::sort(dimensions.batch);
495 llvm::sort(dimensions.m);
496 llvm::sort(dimensions.n);
497 llvm::sort(dimensions.k);
498 return dimensions;
499}
500
/// Infers the contraction dimensions of `linalgOp` from its indexing maps and
/// iterator types. Fails unless the op has exactly 2 inputs and 1 init.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (numbering jumps from 501 to 503).
501FailureOr<ContractionDimensions>
503 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
504 return failure();
505 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
506 linalgOp.getIteratorTypesArray());
507}
508
/// Infers the contraction dimensions from `indexingMaps` alone. The iterator
/// types are derived from the third map via inferIteratorsFromOutMap. Fails
/// unless exactly 3 maps are given and that derivation succeeds.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (numbering jumps from 509 to 511).
509FailureOr<ContractionDimensions>
511 if (indexingMaps.size() != 3)
512 return failure();
513 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
514 if (failed(iterators))
515 return failure();
516 return inferContractionDimsImpl(indexingMaps, iterators.value());
517}
518
519namespace mlir::linalg::detail {
528} // namespace mlir::linalg::detail
529
533 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
534 if (!linalgOp)
536 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
538 auto mapRange = linalgOp.getIndexingMapsArray();
539 if (linalgOp.getNumReductionLoops() == 0)
541 if (llvm::any_of(mapRange,
542 [](AffineMap m) { return !m.isProjectedPermutation(); }))
544 // TODO: more fields than add/mul.
545 // clang-format off
547 arith::MulFOp, arith::AddFOp,
548 arith::MulIOp, arith::AddIOp,
549 complex::MulOp, complex::AddOp,
550 arith::AndIOp, arith::OrIOp>(
551 *linalgOp.getBlock())) {
553 }
554 // clang-format on
555
556 if (dimensions) {
557 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
558 assert(succeeded(res) && "unexpected failure to infer contraction dims");
559 *dimensions = *res;
560 }
562}
563
564StringRef
566 switch (res) {
568 return "expected a LinalgOp";
570 return "expected op with 2 inputs and 1 output";
572 return "expected at least 1 reduction";
574 return "expected indexing maps to be projected permutations";
576 return "expected add/mul op in the body";
578 return "";
579 }
580 llvm_unreachable("unhandled MatchContractionResult case");
581}
582
584 if (!linalgOp)
585 return false;
586 Operation *op = linalgOp.getOperation();
587 return isa<ContractionOpInterface>(op) ||
590}
591
592/// Verify that a LinalgOp `op` is a contraction.
593/// A Linalg contraction is defined in general terms:
594/// 1. Has 2 input and 1 output shapes.
595/// 2. Has at least one reduction dimension.
596/// 3. Has only projected permutation indexing maps.
597/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
598/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
599/// operations that may change the type (e.g. for mixed-precision).
600/// As a consequence, when vectorization of such an op occurs, the only special
601/// behavior is that the (unique) MulOpType is vectorized into a
602/// `vector.contract`. All other ops are handled in a generic fashion.
603/// In the future, we may wish to allow more input arguments and elementwise and
604/// constant operations that do not involve the reduction dimension(s).
611
612//===----------------------------------------------------------------------===//
613// ConvolutionOpInterface implementation
614//===----------------------------------------------------------------------===//
615
616/// Of the given two expressions returns one that is of type T (`lhs` gets
617/// preference over `rhs`)
/// Returns a null expression when neither `lhs` nor `rhs` is of type T.
// NOTE(review): the signature line is missing from this listing (numbering
// jumps from 618 to 620).
618template <typename T>
620 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
621}
622
623namespace {
624/// Walk the indexing expressions for input of a convolution operation to verify
625/// it is of the right form, either
626/// - AffineDimExpr
627/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
628/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
629///
630/// classifies the AffineDimExpr as convolved dimensions or unconvolved
631/// dimensions and verifies each dimension occurs only once.
632struct ConvAccessExprWalker
633 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
634 // Stores dimensions used in expressions of the above form.
635 llvm::SmallDenseSet<int64_t> convolvedDims;
636 // Stores the dual mapping between LHS and RHS of convolution exprs.
637 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
638 // Stores single use dimensions used by an AffineDimExpr.
639 llvm::SmallDenseSet<int64_t> unConvolvedDims;
640 // Stores a mapping from convolved dims to their coefficient.
641 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
642
643 // Removes dims with multiple uses in the source input map from dimension
644 // sets tracked by this walker.
645 void clearMultiUseDims(AffineMap map) {
646 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
647 if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
648 return e.isFunctionOfDim(dimPos);
649 }) > 1) {
650 convolvedDims.erase(dimPos);
651 unConvolvedDims.erase(dimPos);
652 // If a duplicate dim is marked as convolved, the pair of the duplicate
653 // dim must be removed from the map as well.
654 auto it = convolvedDimMapping.find(dimPos);
655 if (it != convolvedDimMapping.end()) {
656 int64_t pairedDim = it->second;
657 convolvedDims.erase(pairedDim);
658 unConvolvedDims.erase(pairedDim);
659 strideAndDilationMapping.erase(pairedDim);
660 convolvedDimMapping.erase(dimPos);
661 convolvedDimMapping.erase(pairedDim);
662 }
663 }
664 }
665 }
666
667 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
668 unsigned position = dimExpr.getPosition();
669 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
670 return failure();
671 }
672 unConvolvedDims.insert(position);
673 return success();
674 }
675
676 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
677
678 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
679
680 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
681 // In pre-order visit, top level op has to be an add op.
682 if (binaryExpr.getKind() != AffineExprKind::Add)
683 return failure();
684 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
685 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
686 if (failed(lhsDimPos) || failed(rhsDimPos))
687 return failure();
688 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
689 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
690 return success();
691 }
692
693 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
694 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
695 int64_t dim = dimExpr.getPosition();
696 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
697 return failure();
698 // Stride/dilation for this dim is implicitly 1.
699 strideAndDilationMapping[dim] =
701 convolvedDims.insert(dim);
702 return dim;
703 }
704 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
705 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
706 return failure();
707 auto lhsExpr = symbolMulExpr.getLHS();
708 auto rhsExpr = symbolMulExpr.getRHS();
709 // Check for symbol expression.
710 AffineExpr mulExpr =
712 // If there was no symbol expr, check for constant expression.
713 if (!mulExpr) {
714 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
715 }
716 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
717 if (!mulExpr || !dimExpr)
718 return failure();
719 int64_t dim = dimExpr.getPosition();
720 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
721 return failure();
722 strideAndDilationMapping[dim] = mulExpr;
723 convolvedDims.insert(dim);
724 return dim;
725 }
726 return failure();
727 }
728};
729} // namespace
730
731static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
732 assert(map.isProjectedPermutation() &&
733 "expected map to have projected permutations");
734 llvm::SmallDenseSet<int64_t> preservedDims;
735 for (auto expr : map.getResults())
736 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
737 return preservedDims;
738}
739
743 for (auto e : exprs) {
744 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
745 assert(constantExpr && "Found non-constant stride/dilation");
746 vals.push_back(constantExpr.getValue());
747 }
748 return vals;
749}
750
751/// Classifies dimensions in the `linalgOp` used by a convolution
752/// subcomputation, as captured by `inputExprWalker`. If
753/// `allowEmptyConvolvedDims` is not set this will fail if there is not
754/// at least one convolved dimension pair (output image + filter loop).
755///
756/// The returned dimensions are ordered as follows:
757/// - `outputImage` is sorted by dimension index.
758/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
759/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
760/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
761/// - `strides[i]` corresponds to `outputImage[i]`.
762/// - `dilations[i]` corresponds to `filterLoop[i]`.
763/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
764static FailureOr<ConvolutionDimensions>
765inferConvolutionDimsImpl(LinalgOp linalgOp,
766 ConvAccessExprWalker &inputExprWalker,
767 bool allowEmptyConvolvedDims) {
768 auto filterMap =
769 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
770 auto outputMap =
771 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
772 llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
773 filterMap, linalgOp.getIteratorTypesArray(), par);
774 llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
775 outputMap, linalgOp.getIteratorTypesArray(), par);
776
777 // unConvolvedDims & outputDims - filterDims are the batch iterators.
778 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
779 llvm::set_intersect(batch, outputDims);
780 llvm::set_subtract(batch, filterDims);
781
782 // convolvedDims & outputDims are the output image iterators.
783 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
784 llvm::set_intersect(oi, outputDims);
785
786 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
787 llvm::SmallDenseSet<int64_t> oc = filterDims;
788 llvm::set_intersect(oc, outputDims);
789 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
790
791 // filterDims & outputDims & unConvolvedDims are the depth iterators.
792 llvm::SmallDenseSet<int64_t> depth = filterDims;
793 llvm::set_intersect(depth, outputDims);
794 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
795
796 llvm::SmallDenseSet<int64_t> filterReducedDims =
798 linalgOp.getIteratorTypesArray(), red);
799
800 // convolvedDims & filterReducedDims are the filter loop iterators.
801 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
802 llvm::set_intersect(fl, filterReducedDims);
803
804 // unConvolvedDims & filterReducedDims are the input channel iterators.
805 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
806 llvm::set_intersect(ic, filterReducedDims);
807
808 if (oi.empty() && !allowEmptyConvolvedDims)
809 return failure();
810
811 // Return each set in sorted order, with outputImage and filterLoop
812 // ordered so that outputImage[i] pairs with filterLoop[i].
813 ConvolutionDimensions dimensions{
814 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
815 SmallVector<unsigned, 2>(oi.begin(), oi.end()),
816 SmallVector<unsigned, 2>(oc.begin(), oc.end()),
817 /*filterLoop=*/SmallVector<unsigned, 2>{},
818 SmallVector<unsigned, 2>(ic.begin(), ic.end()),
819 SmallVector<unsigned, 2>(depth.begin(), depth.end()),
820 /*strides=*/SmallVector<int64_t, 2>{},
821 /*dilations=*/SmallVector<int64_t, 2>{}};
822 llvm::sort(dimensions.batch);
823 llvm::sort(dimensions.outputImage);
824 llvm::sort(dimensions.outputChannel);
825 llvm::sort(dimensions.inputChannel);
826 llvm::sort(dimensions.depth);
827 // Order filterLoop to match the pairing with outputImage. Each outputImage
828 // dimension has a corresponding filterLoop dimension from the convolution
829 // access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
830 // filterLoop[i].
831 for (unsigned oiDim : dimensions.outputImage)
832 dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
833
834 // Use the op carried strides/dilations attribute if present.
835 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
836 if (!nativeStrides) {
837 SmallVector<AffineExpr, 2> strideExprs;
838 for (unsigned oiDim : dimensions.outputImage)
839 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
840 dimensions.strides = getConstantsFromExprList(strideExprs);
841 } else {
842 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
843 }
844 auto nativeDilations =
845 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
846 if (!nativeDilations) {
847 SmallVector<AffineExpr, 2> dilationExprs;
848 for (unsigned flDim : dimensions.filterLoop)
849 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
850 dimensions.dilations = getConstantsFromExprList(dilationExprs);
851 } else {
852 dimensions.dilations =
853 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
854 }
855 return dimensions;
856}
857
858/// Find at least 1 parallel (output_image) and reduction (filter_loop)
859/// dimension candidates that form a convolution subcomputation within
860/// `linalgOp`. The LHS is assumed to be the convolution input while the
861/// RHS is assumed as the filter.
862/// These dimensions are such that:
863/// 1. Optional batch dimensions that appear in the input and filter.
864/// 2. The output_image dimension is involved in a cross-correlation along LHS
865/// (i.e. it is a permutation on RES and LHS and has an associated
866/// filter_loop in RHS).
867/// 3. Optional output_channel dimension is involved in an outer-product along
868/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
869/// LHS).
870/// 4. Optional input_channel dimension appears as a permutation on LHS and
871/// RHS.
872/// 5. The filter_loop dimension appears as a permutation on the RHS and
873/// represents the shape of the kernel cross-correlated along a
874/// corresponding output_image dim.
875/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
876/// 7. All dimensions appear only once in any given indexing map.
877/// This allows e.g. detecting that some convolution is embedded within
878/// `linalgOp` with some orthogonal heuristic.
879///
880/// The `outputImage` and `filterLoop` arrays are ordered such that
881/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
882/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
883/// dimension 2). Other dimension sets are returned in sorted order.
884///
885/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
886FailureOr<ConvolutionDimensions>
888 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
889 return failure();
890
891 auto indexingMaps = linalgOp.getIndexingMapsArray();
892
893 // Check the input indexing map has the right form.
894 ConvAccessExprWalker inputExprWalker;
895 for (AffineExpr expr : indexingMaps[0].getResults())
896 (void)inputExprWalker.visit(expr);
897 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
898
899 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
900 /*allowEmptyConvolvedDims=*/false);
901}
902
namespace mlir::linalg::detail {
/// Result codes produced when matching an operation against the convolution
/// interface. Each non-success code has a corresponding diagnostic string in
/// `getMatchConvolutionMessage`.
// NOTE(review): the enumerator list was dropped from this listing and has been
// restored to match the cases handled by `getMatchConvolutionMessage`; confirm
// the names against the header's forward declaration.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
916
919 Operation *op, ConvolutionDimensions *dimensions,
920 bool allowEmptyConvolvedDims) {
921 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
922 if (!linalgOp)
924 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
926
927 auto indexingMaps = linalgOp.getIndexingMapsArray();
928
929 // Check the input indexing map has the right form.
930 ConvAccessExprWalker inputExprWalker;
931 if (llvm::any_of(indexingMaps[0].getResults(),
932 [&inputExprWalker](AffineExpr expr) {
933 return failed(inputExprWalker.visit(expr));
934 })) {
936 }
937
938 // Filter and output maps must be projected permutation.
939 if (!indexingMaps[1].isProjectedPermutation() ||
940 !indexingMaps.back().isProjectedPermutation())
942
943 auto iteratorTypes = linalgOp.getIteratorTypesArray();
944
945 llvm::SmallDenseSet<int64_t> outputDims =
946 getPreservedDims(indexingMaps.back());
947 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
948 // Make sure all loops are characterized as one of:
949 // - Batch loop : present in output, as non-convolved in input, not present in
950 // filter.
951 // - Output image dimension : present in output, convolved dims in input, not
952 // present in filter.
953 // - Output channel dimension : present in output, not present in input,
954 // present in filter.
955 // - Filter loop dimension : present in filter, convolved in input, not
956 // present in output.
957 // - Input channel dimension : unconvolved in input, not present in output,
958 // present in filter.
959 // - Depth multiplier : unconvolved in input, present in output, present in
960 // filter.
961 llvm::SmallDenseSet<int64_t> allLoopDims;
962 for (auto outputExpr : indexingMaps.back().getResults()) {
963 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
964 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
965 !filterDims.count(outputDim)) {
966 // Batch dimension.
967 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
969 allLoopDims.insert(outputDim);
970 continue;
971 }
972 if (inputExprWalker.convolvedDims.count(outputDim) &&
973 !filterDims.count(outputDim)) {
974 // Output image Loop dimension.
975 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
977 allLoopDims.insert(outputDim);
978 continue;
979 }
980 if (!inputExprWalker.convolvedDims.count(outputDim) &&
981 !inputExprWalker.unConvolvedDims.count(outputDim) &&
982 filterDims.count(outputDim)) {
983 // Output channel dimension.
984 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
986 allLoopDims.insert(outputDim);
987 continue;
988 }
989 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
990 filterDims.count(outputDim)) {
991 // Depth multiplier.
992 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
994 allLoopDims.insert(outputDim);
995 continue;
996 }
998 }
999 for (auto filterExpr : indexingMaps[1].getResults()) {
1000 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1001 if (outputDims.count(filterDim) &&
1002 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1003 !inputExprWalker.convolvedDims.count(filterDim)) {
1004 // Output channel dimension. This is already seen, continue;
1005 continue;
1006 }
1007 if (inputExprWalker.convolvedDims.count(filterDim) &&
1008 !outputDims.count(filterDim)) {
1009 // Filter loop dimension.
1010 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1012 if (allLoopDims.count(filterDim))
1014 allLoopDims.insert(filterDim);
1015 continue;
1016 }
1017 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1018 !outputDims.count(filterDim)) {
1019 // Input channel dimension.
1020 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1022 if (allLoopDims.count(filterDim))
1024 allLoopDims.insert(filterDim);
1025 continue;
1026 }
1027 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1028 outputDims.count(filterDim)) {
1029 // Depthwise loop. Already seen.
1030 continue;
1031 }
1033 }
1034 // All loops must be covered now.
1035 if (allLoopDims.size() != linalgOp.getNumLoops())
1037
1038 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1040
1041 if (dimensions) {
1042 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1043 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1044 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1045 *dimensions = *res;
1046 }
1047
1049}
1050
1051StringRef
1053 switch (res) {
1055 return "expected a LinalgOp";
1057 return "expected op with 2 inputs and 1 output";
1059 return "unexpected input index map for convolutions";
1061 return "expected output/filter indexing maps to be projected permutations";
1063 return "unexpected loop dimension for convolution op";
1065 return "expected all iterators used to access outputs to be parallel";
1067 return "expected all iterators not used to access outputs to be reduction";
1069 return "expected convolved dim to be non-empty";
1071 return "";
1072 }
1073 llvm_unreachable("unhandled MatchConvolutionResult case");
1074}
1075
1077 bool allowEmptyConvolvedDims) {
1079 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1081}
1082
1089
1090//===----------------------------------------------------------------------===//
1091// FillOpInterface implementation
1092//===----------------------------------------------------------------------===//
1093
namespace {
/// Result codes produced by `isFillInterfaceImpl`.
enum class MatchFillResult {
  /// The op satisfies every fill-interface check.
  Success = 0,
  /// The operation is not a LinalgOp.
  NotLinalgOp,
  /// The op does not have exactly one DPS input and one DPS init.
  WrongNumOperands,
  /// The single input is not a scalar.
  NotScalarInput,
  /// The scalar input type differs from the output element type.
  TypeMismatch,
};
} // namespace
1103
1104static MatchFillResult isFillInterfaceImpl(Operation *op) {
1105 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1106 if (!linalgOp)
1107 return MatchFillResult::NotLinalgOp;
1108 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1109 return MatchFillResult::WrongNumOperands;
1110
1111 OpOperand *value = linalgOp.getDpsInputOperand(0);
1112 if (!linalgOp.isScalar(value))
1113 return MatchFillResult::NotScalarInput;
1114
1115 // Check that the scalar input type matches the output element type.
1116 OpOperand *output = linalgOp.getDpsInitOperand(0);
1117 Type scalarType = value->get().getType();
1118 Type outputElementType = getElementTypeOrSelf(output->get().getType());
1119 if (scalarType != outputElementType)
1120 return MatchFillResult::TypeMismatch;
1121
1122 return MatchFillResult::Success;
1123}
1124
1126 MatchFillResult res = isFillInterfaceImpl(op);
1127 if (res == MatchFillResult::NotLinalgOp)
1128 return op->emitError("expected a LinalgOp");
1129 if (res == MatchFillResult::WrongNumOperands)
1130 return op->emitError("expected op with 1 input and 1 output");
1131 if (res == MatchFillResult::NotScalarInput)
1132 return op->emitError("expected op with scalar input");
1133 if (res == MatchFillResult::TypeMismatch) {
1134 auto linalgOp = cast<linalg::LinalgOp>(op);
1135 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1136 Type outputElementType =
1137 getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
1138 return op->emitOpError("expected fill value type (")
1139 << scalarType << ") to match output element type ("
1140 << outputElementType << ")";
1141 }
1142
1143 return success();
1144}
1145
1146//===----------------------------------------------------------------------===//
1147// StructuredOpInterface implementation
1148//===----------------------------------------------------------------------===//
1149
1150SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1151 Location loc) {
1153 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1154 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1155 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1156 }
1157 return res;
1158}
1159
1160SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1162 assert(!hasDynamicShape() && "expected operands to have static shapes");
1163 for (OpOperand &opOperand : getOperation()->getOpOperands())
1164 llvm::append_range(res, getShape(&opOperand));
1165 return res;
1166}
1167
1168SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1169 AffineMap map = getLoopsToShapesMap();
1170 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1171 auto viewSizes = createFlatListOfOperandDims(b, loc);
1172 SmallVector<Range, 4> res(numDims);
1173 for (unsigned idx = 0; idx < numRes; ++idx) {
1174 auto result = map.getResult(idx);
1175 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1176 if (res[d.getPosition()].offset)
1177 continue;
1178 res[d.getPosition()] =
1179 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1180 }
1181 }
1182 return res;
1183}
1184
1185/// Visitor to check if any of the given set of positions from AffineDimExprs
1186/// are used within an AffineExpr.
1188 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1189 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1190 : positions(std::move(positions)) {}
1191
1193 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1194 }
1195
1197 return positions.test(dimExpr.getPosition());
1198 }
1199
1200 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1201
1202 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1203
1204private:
1205 llvm::SmallBitVector positions;
1206};
1207
1208static std::pair<int64_t, int64_t>
1210 int64_t inputRankSum = 0;
1211 int64_t outputRankSum = 0;
1212 for (OpOperand *input : op.getDpsInputOperands())
1213 inputRankSum += op.getRank(input);
1214 for (OpOperand &output : op.getDpsInitsMutable())
1215 outputRankSum += op.getRank(&output);
1216 return {inputRankSum, inputRankSum + outputRankSum};
1217}
1218
1219LogicalResult
1220LinalgOp::reifyResultShapes(OpBuilder &b,
1221 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1222 // An example that helps understand the logic below.
1223 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1224 // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1225 // This is achieved as follows.
1226 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1227 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1228 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1229 // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1230 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1231 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1232
1233 // Find the position in the above map that represents the shape of the
1234 // result:dim being inferred.
1235 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1236
1237 /// From loopsToShapesMap extract the submap that represents the shape of the
1238 /// (resultIdx, dim) needed.
1239 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1240 resultShapesSubMapPos.first,
1241 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1242 AffineMap resultShapesFromInputShapesMap =
1243 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1244
1245 // Check that the result dim map does not contain the positions corresponding
1246 // to the outputs.
1247 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1248 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1249 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1250 Location loc = getOperation()->getLoc();
1251 IRRewriter rewriter(b);
1252 SmallVector<OpFoldResult> allResultDimValues =
1254 rewriter, loc, resultShapesFromInputShapesMap,
1255 createFlatListOfOperandDims(b, loc));
1256 int64_t pos = 0;
1257 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1258 for (OpOperand &opOperand : getDpsInitsMutable()) {
1259 SmallVector<OpFoldResult> shapes;
1260 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1261 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1262 if (!shapedType.isDynamicDim(dim)) {
1263 // Static dim: Return IntegerAttr.
1264 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1265 } else {
1266 // Dynamic dim: Return Value.
1267 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1268 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1269 : allResultDimValues[pos];
1270 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1271 }
1272 pos++;
1273 }
1274 reifiedReturnShapes.emplace_back(std::move(shapes));
1275 }
1276 return success();
1277}
1278
1279/// Return the index in the indexingMaps vector that corresponds to this
1280/// `opOperand`.
1281int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1282 auto operandNumber = opOperand->getOperandNumber();
1283 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1284 if (!dpsIface.isDpsInput(opOperand))
1285 return operandNumber;
1286 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1287 assert(!dpsIface.isDpsInit(opOperand));
1288 // Account for potential inputs that are not DPS and may not appear in
1289 // `indexingMaps`.
1290 return cast<DestinationStyleOpInterface>(*this->getOperation())
1291 .getNumDpsInputs() +
1292 operandNumber - start;
1293}
1294
1296 LinalgOp linalgOp = cast<LinalgOp>(op);
1297 // Mixed tensor/buffer operands are not allowed.
1298 if (!linalgOp.hasPureTensorSemantics() &&
1299 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1300 return op->emitOpError("expected to have pure tensor or buffer semantics");
1301
1302 // Before checking indexing maps, we need to make sure the attributes
1303 // referenced by it are valid.
1304 if (linalgOp.hasDynamicIndexingMaps())
1305 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1306 return failure();
1307
1308 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1309 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1310 return failure();
1311
1312 // Set this flag if this op has user defined maps. This is required to guard
1313 // the below error condition which assume default indexing maps.
1314 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1315 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1316 // Domain must be consistent.
1317 unsigned numLoops = linalgOp.getNumLoops();
1318 if (indexingMap.getNumDims() != numLoops)
1319 return op->emitOpError("expected indexing_map #")
1320 << opOperand.getOperandNumber() << " to have " << numLoops
1321 << " dim(s) to match the number of loops";
1322 }
1323 SmallVector<unsigned> redDims;
1324 linalgOp.getReductionDims(redDims);
1325
1326 if (!linalgOp.getShapesToLoopsMap())
1327 return op->emitOpError("expected the shape-to-loops map to be non-null");
1328
1329 // Check the region has exactly one block.
1330 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1331 return op->emitOpError("expects to have 1 region with 1 block");
1332
1333 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1334 // types.
1335 // TODO: once ranked shape types are plugged in, we may want to drop the
1336 // corresponding bbargs, that can never be read from. This will be subject to
1337 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1338 // not used).
1339 Block &block = linalgOp->getRegion(0).front();
1340
1341 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1342 return op->emitOpError("expected as many non-induction variable region "
1343 "arguments as the number of input/output operands");
1344
1345 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1346 Type elementType = opOperand->get().getType();
1347 if (isa<MemRefType, RankedTensorType>(elementType))
1348 elementType = getElementTypeOrSelf(opOperand->get().getType());
1349 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1350 if (elementType != argType)
1351 return op->emitOpError("expected type of bb argument #")
1352 << opOperand->getOperandNumber() << " (" << argType << ")"
1353 << " to match element or self type of the corresponding operand ("
1354 << elementType << ")";
1355 }
1356
1357 return success();
1358}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:782
unsigned getNumOperands()
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop