MLIR 23.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/MLIRContext.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
41 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
// Decides whether the operands in `droppedOperands` could be removed from
// `linalgOp`. First gather the indexing maps of every operand that would
// remain after dropping them.
42 SmallVector<AffineMap> indexingMaps;
43 for (auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
45 continue;
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47 }
48 if (indexingMaps.empty()) {
49 // If there are no indexing maps, the operand can only be dropped
50 // if the op has no loops.
51 return linalgOp.getNumLoops() == 0;
52 }
// Otherwise the remaining maps are combined (with the op's context) and the
// drop is allowed only if the result is not a null AffineMap.
// NOTE(review): the head of this expression is not visible in this listing;
// it presumably inverts the concatenation of `indexingMaps`, so that every
// loop is still determined by some remaining operand — TODO confirm.
54 indexingMaps, linalgOp.getContext())) != AffineMap();
55}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Check yield first block argument.
72 Block *body = op.getBlock();
73 if (body->getOperations().size() != 1)
74 return false;
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
77 return false;
78 return yieldOp->getOperand(0) == body->getArgument(0);
79}
80
81//===----------------------------------------------------------------------===//
82// FillOpInterface implementation
83//===----------------------------------------------------------------------===//
84/// Detects if a linalg.generic operation represents a fill with an inlined
85/// constant. If so, returns the constant value. Otherwise, returns
86/// std::nullopt.
87static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
90 return std::nullopt;
91
92 // Init should not be referenced.
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94 return std::nullopt;
95
96 Block *body = op.getBody();
97 if (body->getOperations().size() != 1)
98 return std::nullopt;
99
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
102 return std::nullopt;
103
104 Value yieldOperand = yieldOp->getOperand(0);
105 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106 !yieldOperand.getDefiningOp<complex::ConstantOp>())
107 return std::nullopt;
108
109 return yieldOperand;
110}
111
112/// Detects if a linalg.generic operation represents an external scalar input.
113/// If so, returns the constant value. Otherwise, returns std::nullopt.
114static std::optional<Value> isaExternalFillOp(GenericOp op) {
115 // Structural.
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
118 return std::nullopt;
119
120 // Input should be referenced and init should not.
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123 return std::nullopt;
124
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
127 return std::nullopt;
128 return value->get();
129}
130
131std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132 if (auto fillVal = isaInlinedFillOp(op))
133 return fillVal;
134 return isaExternalFillOp(op);
135}
136
137//===----------------------------------------------------------------------===//
138// BroadcastOpInterface implementation
139//===----------------------------------------------------------------------===//
/// Returns the broadcast dimensions of `linalgOp` in either of two cases:
/// - It is a linalg.broadcast. The values of `getDimensions()` are returned
///   as is.
/// - It is a linalg.generic that is structurally equivalent to a broadcast:
///   all loops are parallel; there is one input and one init; the body is
///   only a yield; both operand types are ranked tensors or memrefs; the init
///   map is the identity; and the input map's results are plain dims in
///   strictly increasing order, with fewer results than the init map. In
///   this case the result lists, in ascending order, the loop dims in
///   [0, numDims) that the input map does not reference.
/// Returns std::nullopt otherwise.
140std::optional<SmallVector<int64_t>>
142 if (auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
143 return SmallVector<int64_t>(broadcastOp.getDimensions().begin(),
144 broadcastOp.getDimensions().end());
145
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
147 if (!op)
148 return std::nullopt;
149
150 // Structural.
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
153 return std::nullopt;
154
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
159 return std::nullopt;
160
161 // Check output is identity map. Broadcast could additionally be
162 // employing permutation of indices and that would be expressible
163 // in linalg.generic but is not expressible for named broadcast op.
// NOTE(review): getIndexingMapsArray() is invoked twice in this function;
// one local copy would avoid rebuilding the array.
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
166 return std::nullopt;
167
168 SmallVector<int64_t> position;
169 auto srcMap = op.getIndexingMapsArray()[0];
170
// A broadcast must add at least one dimension, so equal or higher input rank
// is rejected.
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
172 return std::nullopt;
173
174 // Check input map is monotonically increasing DimIds.
175 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
177 if (!expr)
178 return std::nullopt;
179 int64_t pos = expr.getPosition();
// Strictly increasing: equal positions (repeated dims) are rejected too.
180 if (i > 0 && pos <= position[i - 1])
181 return std::nullopt;
182 position.push_back(expr.getPosition());
183 }
184
// Every loop dim not read by the input is a broadcast dim.
185 SmallVector<int64_t> broadcastedDims;
186 auto numDims = srcMap.getNumDims();
187 // This is quadratic but number of items is generally small.
188 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
191 }
192 return broadcastedDims;
193}
194
195//===----------------------------------------------------------------------===//
196// TransposeOpInterface implementation
197//===----------------------------------------------------------------------===//
/// Recognizes an op (`op`) that is a pure transpose. The requirements are:
/// all loops parallel; one input and one init; a body that is only a yield;
/// an identity output map; and an input map that is a permutation.
/// On a match, returns `permutation` such that
/// dim(result, i) == dim(input, permutation[i]). This is the inverse of the
/// order in which the input map lists loop dims. Returns std::nullopt when
/// the op does not match.
198std::optional<SmallVector<int64_t>>
200 // To specialize as a transpose op, the genericOp must be
201 // all parallel loops, single input, single output, and its body
202 // should be just a yield op, yielding input as output as is (no compute).
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
205 return std::nullopt;
206
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
209 return std::nullopt;
210
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
213
214 // linalg.transpose permutes the dimensions of input using this
215 // rule: dim(result, i) = dim(input, permutation[i])
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
217 return std::nullopt;
218
// isPermutation() implies #results == #dims, so indexing getResults() with a
// dim count is in bounds. Input result i is indexed by loop dim p; with an
// identity output map, output dim p is therefore input dim i.
219 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
220 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
223 }
224 return permutation;
225}
226
227//===----------------------------------------------------------------------===//
228// Elementwise Single Unary/Binary-OpInterface implementation
229//===----------------------------------------------------------------------===//
230static bool
231isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
232 bool allowNonIdentityMaps) {
233 // Check all loops are parallel.
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
235 return false;
236
237 // Check there are arity-inputs, 1-output and all are identity-maps (unless
238 // requested otherwise).
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](AffineMap map) { return map.isIdentity(); })))
243 return false;
244
245 // Init should not be referenced for elementwise operations.
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
247 return false;
248
249 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
250 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
251 // the body, where the first is the elementwise single op and the second a
252 // yield.
253 Block *body = op.getBody();
254 if (body->getOperations().size() != 2)
255 return false;
256
257 // The payload op must have one result and at least arity-many operands
258 // (otherwise not all inputs can be used). It can have additional operands
259 // from outside of the generic op (e.g. div(1, x) for linalg.reciprocal) or
260 // use an input more than once (e.g. mul(x, x) for linalg.square).
261 Operation *oper = &body->front();
262 if (oper->getNumOperands() < arity || oper->getNumResults() != 1)
263 return false;
264
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
268}
269
270bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
271 bool allowNonIdentityMaps) {
272 // All basic elemwise checks.
273 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
274 return false;
275
276 // Check input is actually used.
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
278 return false;
279 return true;
280}
281
282bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
283 bool allowNonIdentityMaps) {
284 // All basic elemwise checks.
285 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2, allowNonIdentityMaps))
286 return false;
287
288 // Check both inputs are used (elementwise).
289 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
290 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
291 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
292 !op.payloadUsesValueFromOperand(inputOpOperand1));
293}
294
295//===----------------------------------------------------------------------===//
296// ContractionOpInterface implementation
297//===----------------------------------------------------------------------===//
298
299/// If the value is defined by a chain of unary, side-effect-free ops, go up the
300/// use-def chain until the first value that isn't defined by such an op.
/// Here "unary" means the op has exactly one operand; its result count is not
/// checked. "Side-effect-free" means the op implements MemoryEffectOpInterface
/// and reports no effects; an op without that interface stops the walk.
/// Returns `value` itself when it has no defining op (e.g. a block argument)
/// or its defining op does not qualify.
301// TODO: relax to multi-operands with constants, which are technically unary ops
302// as needed (e.g. add5).
304 Operation *op = value.getDefiningOp();
305 while (op && op->getNumOperands() == 1) {
306 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
307 if (!iface || !iface.hasNoEffect())
308 break;
// Step to the op's single operand and keep climbing.
309 value = op->getOperand(0);
310 op = value.getDefiningOp();
311 }
312 return value;
313}
314
316 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
317 llvm::raw_ostream &errs) {
// Checks that `block` has the shape of a contraction body. The block must
// take exactly three arguments, and its single yielded value must be,
// modulo chains of unary side-effect-free ops (see getSourceSkipUnary),
//   reductionOp(arg #2, elementwiseOp(arg #0, arg #1)).
// Both ops must have one result and two operands, and
// `isaPair(elementwiseOp, reductionOp)` must accept their kinds. Operand
// order is not significant for either op. On any mismatch a diagnostic is
// streamed into `errs` and false is returned.
318 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
319 errs << "no terminator in the block";
320 return false;
321 }
322
323 if (block.getNumArguments() != 3) {
324 errs << "expected block with 3 arguments";
325 return false;
326 }
327
328 Operation *terminator = block.getTerminator();
329 if (terminator->getNumOperands() != 1) {
330 errs << "expected terminator with 1 operand";
331 return false;
332 }
333
// The (unary-stripped) yielded value must come from a binary reduction op.
334 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
335 Operation *reductionOp = yielded.getDefiningOp();
336 if (!reductionOp || reductionOp->getNumResults() != 1 ||
337 reductionOp->getNumOperands() != 2) {
338 errs << "expected reduction op to be binary";
339 return false;
340 }
341
342 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
343 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
344
345 if (reductionLHS != block.getArgument(2) &&
346 reductionRHS != block.getArgument(2)) {
347 errs << "expected reduction to take block argument #2 as one of the "
348 "operands (modulo unary casts)";
349 return false;
350 }
351
// Pick the reduction operand that is not block argument #2.
// NOTE(review): the test is `isa<BlockArgument>(reductionLHS)`, not
// `reductionLHS == block.getArgument(2)`. Given the check above, a different
// block argument on the LHS routes to the RHS (which is then arg #2). Arg #2
// has no defining op, so it is rejected just below; the outcome is still a
// rejection, but the diagnostic may be less precise.
352 Value contributed = getSourceSkipUnary(
353 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
354 Operation *elementwiseOp = contributed.getDefiningOp();
355 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
356 elementwiseOp->getNumOperands() != 2) {
357 errs << "expected elementwise op to be binary";
358 return false;
359 }
360
361 if (!isaPair(elementwiseOp, reductionOp)) {
362 errs << "expected reduction/elementwise op kind not satisfied";
363 return false;
364 }
365
// The elementwise op must combine block arguments #0 and #1, in either order.
366 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
367 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
368 if ((elementwiseLHS == block.getArgument(0) &&
369 elementwiseRHS == block.getArgument(1)) ||
370 (elementwiseLHS == block.getArgument(1) &&
371 elementwiseRHS == block.getArgument(0))) {
372 return true;
373 }
374
375 errs << "expected elementwise op to apply to block arguments (modulo unary "
376 "casts)";
377 return false;
378}
379
380/// Returns true if the two operations are of the kinds specified by a pair of
381/// consecutive template arguments.
/// The template arguments form a flat list of (first, second) type pairs, and
/// the number of trailing `Args` is static-asserted to be even. The first op
/// is tested against the first type of a pair and the second op against the
/// second type.
/// NOTE(review): in this file, isContractionBody invokes
/// `isaPair(elementwiseOp, reductionOp)`, and the visible instantiation lists
/// pairs as <MulFOp, AddFOp>, <MulIOp, AddIOp>, ... So `AddOpTy`/`add`
/// appear to bind the elementwise (multiply-like) op and `MulOpTy`/`mul` the
/// reduction op. The names look swapped relative to their use, though the
/// behavior is consistent. TODO confirm against the elided glue code and
/// consider renaming.
382template <typename AddOpTy, typename MulOpTy, typename... Args>
384 static_assert(sizeof...(Args) % 2 == 0,
385 "expected an even number of template arguments");
386 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
387 return true;
388
// When pairs remain, the (elided) next statement presumably recurses on the
// remaining `Args`; otherwise no pair matched.
389 if constexpr (sizeof...(Args) > 0)
391 else
392 return false;
393}
394
395/// Returns true if the block is a body of a contraction with the kinds of
396/// operations given pairwise by template arguments.
397template <typename... Args>
401
402/// Given an `indexingMap` and its corresponding `iterators`, returns
403/// the positions of the iterators of type `iter` that are indexed by
404/// the `indexingMap` as a permutation. This is useful to infer various
405/// subcomputations on a `LinalgOp`. This is performed by looking up
406/// each result in the `indexingMap` and determining whether:
407/// - It is a single AffineDimExpr.
408/// - It is the only result involving this AffineDimExpr.
/// - Its iterator type in `iterators` equals `iter`.
/// A dim that also appears inside another (compound) result is excluded.
/// Requires `iterators.size() == indexingMap.getNumDims()` (asserted). The
/// scan is quadratic in the number of map results.
409static llvm::SmallDenseSet<int64_t>
412 utils::IteratorType iter) {
413 assert(iterators.size() == indexingMap.getNumDims());
414 llvm::SmallDenseSet<int64_t> res;
415 for (AffineExpr e : indexingMap.getResults()) {
416 if (auto d = dyn_cast<AffineDimExpr>(e)) {
// Count every result that mentions this dim, including compound ones.
417 if (iterators[d.getPosition()] == iter &&
418 llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
419 return e.isFunctionOfDim(d.getPosition());
420 }) == 1)
421 res.insert(d.getPosition());
422 }
423 }
424 return res;
425}
426
427namespace {
428auto par = utils::IteratorType::parallel;
429auto red = utils::IteratorType::reduction;
430} // namespace
431
432/// Infer the iterator types from the init affine map. This looks at which dims
433/// are present in the map results, and returns an iterator types array with
434/// parallel types for dims that are present, and reduction types for dims that
435/// are not present.
/// Returns failure() if `map` is not a projected permutation.
436static FailureOr<SmallVector<utils::IteratorType>>
438 if (!map.isProjectedPermutation())
439 return failure();
// Start with every dim marked as a reduction, then flip the ones the init
// map references to parallel.
440 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
441 for (auto expr : map.getResults())
442 if (auto dim = dyn_cast<AffineDimExpr>(expr))
443 iterators[dim.getPosition()] = par;
444 return iterators;
445}
446
447/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
448/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
449/// 1. The m dimension is involved in an outer-product along LHS
450/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
451/// 2. The n dimension is involved in an outer-product along RHS
452/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
453/// 3. The k dimension appears as a permutation on LHS and RHS.
454/// 4. m, n and k appear only once in any given indexing.
455/// 5. Optional batch dimensions that appear in all operands are captured.
456/// This allows e.g. detecting that some contraction is embedded within
457/// `linalgOp` with some orthogonal heuristic.
/// Each of batch/m/n/k is returned as a sorted, possibly empty list; the
/// visible body never returns failure().
/// NOTE(review): "does not appear in RHS/LHS" is implemented by subtracting
/// only the dims the other operand indexes as plain, single-use parallel dims
/// (see findPermutationsIndexingOperand). A dim used by the other operand
/// inside a compound expression is therefore not excluded.
458static FailureOr<ContractionDimensions>
461 llvm::SmallDenseSet<int64_t> a =
462 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
463 llvm::SmallDenseSet<int64_t> b =
464 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
465 llvm::SmallDenseSet<int64_t> c =
466 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
467
468 // A & C - B are the iterators involved in an outer-product along A (the LHS).
469 llvm::SmallDenseSet<int64_t> ac = a;
470 llvm::set_intersect(ac, c);
471 llvm::set_subtract(ac, b);
472 // B & C - A are the iterators involved in an outer-product along B (the RHS).
473 llvm::SmallDenseSet<int64_t> bc = b;
474 llvm::set_intersect(bc, c);
475 llvm::set_subtract(bc, a);
476 // A & B & C are the "batch" dimensions.
477 llvm::SmallDenseSet<int64_t> batches = a;
478 llvm::set_intersect(batches, b);
479 llvm::set_intersect(batches, c);
480
481 // A & B red are the reduction dimensions.
482 llvm::SmallDenseSet<int64_t> ra =
483 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
484 llvm::SmallDenseSet<int64_t> rb =
485 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
486 llvm::set_intersect(ra, rb);
487
488 // Return each set in sorted order.
// Dense sets iterate in unspecified order, hence the explicit sorts.
489 ContractionDimensions dimensions{
490 SmallVector<unsigned, 2>(batches.begin(), batches.end()),
491 SmallVector<unsigned, 2>(ac.begin(), ac.end()),
492 SmallVector<unsigned, 2>(bc.begin(), bc.end()),
493 SmallVector<unsigned, 2>(ra.begin(), ra.end())};
494 llvm::sort(dimensions.batch);
495 llvm::sort(dimensions.m);
496 llvm::sort(dimensions.n);
497 llvm::sort(dimensions.k);
498 return dimensions;
499}
500
/// Infers the contraction dimensions (batch, m, n, k) of `linalgOp` from its
/// indexing maps and iterator types. Fails unless the op has exactly two DPS
/// inputs and one DPS init.
501FailureOr<ContractionDimensions>
503 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
504 return failure();
505 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
506 linalgOp.getIteratorTypesArray());
507}
508
/// Maps-only variant. It expects exactly three indexing maps (LHS, RHS,
/// init) and derives iterator types from the init map via
/// inferIteratorsFromOutMap: dims present in the init map are parallel and
/// the rest are reductions. Fails if there are not three maps or the init map
/// is not a projected permutation.
509FailureOr<ContractionDimensions>
511 if (indexingMaps.size() != 3)
512 return failure();
513 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
514 if (failed(iterators))
515 return failure();
516 return inferContractionDimsImpl(indexingMaps, iterators.value());
517}
518
519namespace mlir::linalg::detail {
528} // namespace mlir::linalg::detail
529
533 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
534 if (!linalgOp)
536 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
538 auto mapRange = linalgOp.getIndexingMapsArray();
539 if (linalgOp.getNumReductionLoops() == 0)
541 if (llvm::any_of(mapRange,
542 [](AffineMap m) { return !m.isProjectedPermutation(); }))
544 // TODO: more fields than add/mul.
545 // clang-format off
547 arith::MulFOp, arith::AddFOp,
548 arith::MulIOp, arith::AddIOp,
549 complex::MulOp, complex::AddOp,
550 arith::AndIOp, arith::OrIOp>(
551 *linalgOp.getBlock())) {
553 }
554 // clang-format on
555
556 if (dimensions) {
557 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
558 assert(succeeded(res) && "unexpected failure to infer contraction dims");
559 *dimensions = *res;
560 }
562}
563
564StringRef
566 switch (res) {
568 return "expected a LinalgOp";
570 return "expected op with 2 inputs and 1 output";
572 return "expected at least 1 reduction";
574 return "expected indexing maps to be projected permutations";
576 return "expected add/mul op in the body";
578 return "";
579 }
580 llvm_unreachable("unhandled MatchContractionResult case");
581}
582
584 if (!linalgOp)
585 return false;
586 Operation *op = linalgOp.getOperation();
587 return isa<ContractionOpInterface>(op) ||
590}
591
592/// Verify that a LinalgOp `op` is a contraction.
593/// A Linalg contraction is defined in general terms:
594/// 1. Has 2 input operands and 1 output operand.
595/// 2. Has at least one reduction dimension.
596/// 3. Has only projected permutation indexing maps.
597/// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
598/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
599/// operations that may change the type (e.g. for mixed-precision).
600/// As a consequence, when vectorization of such an op occurs, the only special
601/// behavior is that the (unique) MulOpType is vectorized into a
602/// `vector.contract`. All other ops are handled in a generic fashion.
603/// In the future, we may wish to allow more input arguments and elementwise and
604/// constant operations that do not involve the reduction dimension(s).
611
612//===----------------------------------------------------------------------===//
613// ConvolutionOpInterface implementation
614//===----------------------------------------------------------------------===//
615
616/// Of the given two expressions returns one that is of type T (`lhs` gets
617/// preference over `rhs`).
/// Returns a null `T` (constructed from nullptr) when neither expression is
/// of type `T`.
618template <typename T>
620 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
621}
622
623namespace {
624/// Walk the indexing expressions for input of a convolution operation to verify
625/// it is of one of the following forms:
626/// - AffineDimExpr
627/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
628/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
629///
630/// It classifies each AffineDimExpr as a convolved or an unconvolved
631/// dimension and verifies each dimension occurs only once.
/// NOTE(review): as implemented, visitAffineBinaryOpExpr accepts a single `+`
/// of two terms, each a dim or a dim multiplied by a symbol/constant. A
/// longer sum such as `d0 + d1 + d2` has a nested `+` on one side, which
/// getDimExprOrMulExprDimPos rejects. Both sides of a `+` are evaluated
/// before failure is checked, so a failing `+` can leave the successfully
/// parsed side in `convolvedDims` (and `strideAndDilationMapping`) without an
/// entry in `convolvedDimMapping`. The visit() result is ignored by at least
/// one caller in this file.
632struct ConvAccessExprWalker
633 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
634 // Stores dimensions used in expressions of the above form.
635 llvm::SmallDenseSet<int64_t> convolvedDims;
636 // Stores the dual mapping between LHS and RHS of convolution exprs.
637 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
638 // Stores single use dimensions used by an AffineDimExpr.
639 llvm::SmallDenseSet<int64_t> unConvolvedDims;
640 // Stores a mapping from convolved dims to their coefficient.
641 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
642
643 // Removes dims with multiple uses in the source input map from dimension
644 // sets tracked by this walker.
645 void clearMultiUseDims(AffineMap map) {
646 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
647 if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
648 return e.isFunctionOfDim(dimPos);
649 }) > 1) {
650 convolvedDims.erase(dimPos);
651 unConvolvedDims.erase(dimPos);
652 // If a duplicate dim is marked as convolved, the pair of the duplicate
653 // dim must be removed from the map as well.
// Note: strideAndDilationMapping[dimPos] itself is left in place. It appears
// unused afterwards because dimPos is no longer in convolvedDims.
654 auto it = convolvedDimMapping.find(dimPos);
655 if (it != convolvedDimMapping.end()) {
656 int64_t pairedDim = it->second;
657 convolvedDims.erase(pairedDim);
658 unConvolvedDims.erase(pairedDim);
659 strideAndDilationMapping.erase(pairedDim);
660 convolvedDimMapping.erase(dimPos);
661 convolvedDimMapping.erase(pairedDim);
662 }
663 }
664 }
665 }
666
// A bare dim result: record it as unconvolved. Fails if the dim was already
// seen in either role.
667 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
668 unsigned position = dimExpr.getPosition();
669 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
670 return failure();
671 }
672 unConvolvedDims.insert(position);
673 return success();
674 }
675
// A bare symbol or constant is not a valid input access.
676 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
677
678 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
679
680 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
681 // In pre-order visit, top level op has to be an add op.
682 if (binaryExpr.getKind() != AffineExprKind::Add)
683 return failure();
684 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
685 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
686 if (failed(lhsDimPos) || failed(rhsDimPos))
687 return failure();
// Record the pairing in both directions (e.g. output image <-> filter loop).
688 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
689 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
690 return success();
691 }
692
// Parses one term of a `+`: either a plain dim (coefficient 1) or a dim
// multiplied by a symbol or constant. On success it records the dim as
// convolved, stores its coefficient, and returns the dim position.
693 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
694 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
695 int64_t dim = dimExpr.getPosition();
696 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
697 return failure();
698 // Stride/dilation for this dim is implicitly 1.
699 strideAndDilationMapping[dim] =
701 convolvedDims.insert(dim);
702 return dim;
703 }
704 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
705 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
706 return failure();
707 auto lhsExpr = symbolMulExpr.getLHS();
708 auto rhsExpr = symbolMulExpr.getRHS();
709 // Check for symbol expression.
710 AffineExpr mulExpr =
712 // If there was no symbol expr, check for constant expression.
713 if (!mulExpr) {
714 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
715 }
716 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
717 if (!mulExpr || !dimExpr)
718 return failure();
719 int64_t dim = dimExpr.getPosition();
720 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
721 return failure();
722 strideAndDilationMapping[dim] = mulExpr;
723 convolvedDims.insert(dim);
724 return dim;
725 }
726 return failure();
727 }
728};
729} // namespace
730
731static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
732 assert(map.isProjectedPermutation() &&
733 "expected map to have projected permutations");
734 llvm::SmallDenseSet<int64_t> preservedDims;
735 for (auto expr : map.getResults())
736 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
737 return preservedDims;
738}
739
// Converts each stride/dilation expression in `exprs` to its integer value,
// preserving order. Only AffineConstantExpr entries are supported.
// NOTE(review): the check is only an assert. With NDEBUG, a non-constant
// (e.g. symbolic) entry would make `constantExpr` null and `getValue()` would
// dereference it. A null entry in `exprs` would likewise be unsafe to
// dyn_cast.
743 for (auto e : exprs) {
744 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
745 assert(constantExpr && "Found non-constant stride/dilation");
746 vals.push_back(constantExpr.getValue());
747 }
748 return vals;
749}
750
751/// Classifies dimensions in the `indexingMaps` used by a convolution
752/// subcomputation, as captured by `inputExprWalker`. If
753/// `allowEmptyConvolvedDims` is not set this will fail if there is not
754/// at least one convolved dimension pair (output image + filter loop).
755///
756/// The returned dimensions are ordered as follows:
757/// - `outputImage` is sorted by dimension index.
758/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
759/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
760/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
761/// - `strides[i]` corresponds to `outputImage[i]`.
762/// - `dilations[i]` corresponds to `filterLoop[i]`.
763/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
764///
765/// `nativeStrides` and `nativeDilations`, when non-null, are the op-carried
766/// `strides`/`dilations` attributes and take precedence over the values derived
767/// from the convolution access pattern. They are null for the maps-based
768/// overload.
/// indexingMaps[1] is treated as the filter map and the last map as the
/// output map. `inputExprWalker` is taken by non-const reference and can be
/// mutated here (see the operator[] lookups below).
769static FailureOr<ConvolutionDimensions> inferConvolutionDimsImpl(
771 ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims,
772 DenseIntElementsAttr nativeStrides, DenseIntElementsAttr nativeDilations) {
773 AffineMap filterMap = indexingMaps[1];
774 AffineMap outputMap = indexingMaps.back();
775 llvm::SmallDenseSet<int64_t> filterDims =
776 findPermutationsIndexingOperand(filterMap, iterators, par);
777 llvm::SmallDenseSet<int64_t> outputDims =
778 findPermutationsIndexingOperand(outputMap, iterators, par);
779
780 // unConvolvedDims & outputDims - filterDims are the batch iterators.
781 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
782 llvm::set_intersect(batch, outputDims);
783 llvm::set_subtract(batch, filterDims);
784
785 // convolvedDims & outputDims are the output image iterators.
786 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
787 llvm::set_intersect(oi, outputDims);
788
789 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
790 llvm::SmallDenseSet<int64_t> oc = filterDims;
791 llvm::set_intersect(oc, outputDims);
792 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
793
794 // filterDims & outputDims & unConvolvedDims are the depth iterators.
795 llvm::SmallDenseSet<int64_t> depth = filterDims;
796 llvm::set_intersect(depth, outputDims);
797 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
798
799 llvm::SmallDenseSet<int64_t> filterReducedDims =
800 findPermutationsIndexingOperand(filterMap, iterators, red);
801
802 // convolvedDims & filterReducedDims are the filter loop iterators.
// NOTE(review): `fl` is not read anywhere below. `filterLoop` is instead
// derived from `convolvedDimMapping`, and nothing checks that those partner
// dims belong to `fl`.
803 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
804 llvm::set_intersect(fl, filterReducedDims);
805
806 // unConvolvedDims & filterReducedDims are the input channel iterators.
807 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
808 llvm::set_intersect(ic, filterReducedDims);
809
810 if (oi.empty() && !allowEmptyConvolvedDims)
811 return failure();
812
813 // Return each set in sorted order, with outputImage and filterLoop
814 // ordered so that outputImage[i] pairs with filterLoop[i].
815 ConvolutionDimensions dimensions{
816 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
817 SmallVector<unsigned, 2>(oi.begin(), oi.end()),
818 SmallVector<unsigned, 2>(oc.begin(), oc.end()),
819 /*filterLoop=*/SmallVector<unsigned, 2>{},
820 SmallVector<unsigned, 2>(ic.begin(), ic.end()),
821 SmallVector<unsigned, 2>(depth.begin(), depth.end()),
822 /*strides=*/SmallVector<int64_t, 2>{},
823 /*dilations=*/SmallVector<int64_t, 2>{}};
824 llvm::sort(dimensions.batch);
825 llvm::sort(dimensions.outputImage);
826 llvm::sort(dimensions.outputChannel);
827 llvm::sort(dimensions.inputChannel);
828 llvm::sort(dimensions.depth);
829 // Order filterLoop to match the pairing with outputImage. Each outputImage
830 // dimension has a corresponding filterLoop dimension from the convolution
831 // access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
832 // filterLoop[i].
// NOTE(review): DenseMap::operator[] default-inserts when `oiDim` has no
// recorded partner, which silently yields filter dim 0 and mutates the
// walker. In ConvAccessExprWalker, a `+` whose other side failed to parse
// appears able to leave a convolved dim without a partner. A `lookup`/`find`
// with an explicit failure() would be stricter — TODO confirm.
833 for (unsigned oiDim : dimensions.outputImage)
834 dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
835
836 // Use the op carried strides/dilations attribute if present.
837 if (!nativeStrides) {
838 SmallVector<AffineExpr, 2> strideExprs;
839 for (unsigned oiDim : dimensions.outputImage)
840 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
841 dimensions.strides = getConstantsFromExprList(strideExprs);
842 } else {
843 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
844 }
845 if (!nativeDilations) {
// NOTE(review): for a default-inserted filter dim (see above),
// operator[] here can likewise insert and return a null AffineExpr.
846 SmallVector<AffineExpr, 2> dilationExprs;
847 for (unsigned flDim : dimensions.filterLoop)
848 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
849 dimensions.dilations = getConstantsFromExprList(dilationExprs);
850 } else {
851 dimensions.dilations =
852 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
853 }
854 return dimensions;
855}
856
857/// Find at least 1 parallel (output_image) and reduction (filter_loop)
858/// dimension candidates that form a convolution subcomputation within
859/// `linalgOp`. The LHS is assumed to be the convolution input while the
860/// RHS is assumed as the filter.
861/// These dimensions are such that:
862/// 1. Optional batch dimensions appear (unconvolved) in the input and in the
/// output, but not in the filter.
863/// 2. The output_image dimension is involved in a cross-correlation along LHS
864/// (i.e. it is a permutation on RES and LHS and has an associated
865/// filter_loop in RHS).
866/// 3. Optional output_channel dimension is involved in an outer-product along
867/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
868/// LHS).
869/// 4. Optional input_channel dimension appears as a permutation on LHS and
870/// RHS.
871/// 5. The filter_loop dimension appears as a permutation on the RHS and
872/// represents the shape of the kernel cross-correlated along a
873/// corresponding output_image dim.
874/// 6. Optional depth (multiplier) dimension appears unconvolved in LHS and
/// also appears in both RES and RHS.
875/// 7. All dimensions appear only once in any given indexing map.
876/// This allows e.g. detecting that some convolution is embedded within
877/// `linalgOp` with some orthogonal heuristic.
878///
879/// The `outputImage` and `filterLoop` arrays are ordered such that
880/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
881/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
882/// dimension 2). Other dimension sets are returned in sorted order.
883///
/// Strides and dilations come from the op's `strides` / `dilations`
/// DenseIntElementsAttr attributes when present; otherwise they are derived
/// from the input access expressions.
///
884/// Returns a failure if `linalgOp` does not have exactly two inputs and one
/// init, or if `output_image` (and implicitly `filter_loop`) is empty.
885FailureOr<ConvolutionDimensions>
887 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
888 return failure();
889
890 auto indexingMaps = linalgOp.getIndexingMapsArray();
891
892 // Check the input indexing map has the right form.
893 ConvAccessExprWalker inputExprWalker;
// NOTE(review): visit() failures are deliberately ignored here (unlike in
// isConvolutionInterfaceImpl); the walker is then pruned of dims that occur
// more than once in the input map — confirm this leniency is intended.
894 for (AffineExpr expr : indexingMaps[0].getResults())
895 (void)inputExprWalker.visit(expr);
896 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
897
899 indexingMaps, linalgOp.getIteratorTypesArray(), inputExprWalker,
900 /*allowEmptyConvolvedDims=*/false,
901 linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"),
902 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations"));
903}
904
/// Variant of inferConvolutionDims that works directly on a list of indexing
/// maps instead of an op. indexingMaps[0] is walked as the convolution input
/// and the iterator types are inferred from indexingMaps[2] (the output map).
/// No op attributes are available in this form, so strides and dilations are
/// always derived from the input access expressions.
/// Fails if there are not exactly three maps, if iterator inference fails, or
/// if no convolved (output_image) dimension is found.
905FailureOr<ConvolutionDimensions>
907 if (indexingMaps.size() != 3)
908 return failure();
909
910 // Infer iterator types from the output map.
911 FailureOr<SmallVector<utils::IteratorType>> iterators =
912 inferIteratorsFromOutMap(indexingMaps[2]);
913 if (failed(iterators))
914 return failure();
915
916 // Check the input indexing map has the right form.
917 ConvAccessExprWalker inputExprWalker;
918 for (AffineExpr expr : indexingMaps[0].getResults())
919 (void)inputExprWalker.visit(expr);
920 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
921
922 return inferConvolutionDimsImpl(indexingMaps, iterators.value(),
923 inputExprWalker,
924 /*allowEmptyConvolvedDims=*/false,
925 /*nativeStrides=*/nullptr,
926 /*nativeDilations=*/nullptr);
927}
928
929namespace mlir::linalg::detail {
941} // namespace mlir::linalg::detail
942
945 Operation *op, ConvolutionDimensions *dimensions,
946 bool allowEmptyConvolvedDims) {
947 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
948 if (!linalgOp)
950 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
952
953 auto indexingMaps = linalgOp.getIndexingMapsArray();
954
955 // Check the input indexing map has the right form.
956 ConvAccessExprWalker inputExprWalker;
957 if (llvm::any_of(indexingMaps[0].getResults(),
958 [&inputExprWalker](AffineExpr expr) {
959 return failed(inputExprWalker.visit(expr));
960 })) {
962 }
963
964 // Filter and output maps must be projected permutation.
965 if (!indexingMaps[1].isProjectedPermutation() ||
966 !indexingMaps.back().isProjectedPermutation())
968
969 auto iteratorTypes = linalgOp.getIteratorTypesArray();
970
971 llvm::SmallDenseSet<int64_t> outputDims =
972 getPreservedDims(indexingMaps.back());
973 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
974 // Make sure all loops are characterized as one of:
975 // - Batch loop : present in output, as non-convolved in input, not present in
976 // filter.
977 // - Output image dimension : present in output, convolved dims in input, not
978 // present in filter.
979 // - Output channel dimension : present in output, not present in input,
980 // present in filter.
981 // - Filter loop dimension : present in filter, convolved in input, not
982 // present in output.
983 // - Input channel dimension : unconvolved in input, not present in output,
984 // present in filter.
985 // - Depth multiplier : unconvolved in input, present in output, present in
986 // filter.
987 llvm::SmallDenseSet<int64_t> allLoopDims;
988 for (auto outputExpr : indexingMaps.back().getResults()) {
989 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
990 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
991 !filterDims.count(outputDim)) {
992 // Batch dimension.
993 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
995 allLoopDims.insert(outputDim);
996 continue;
997 }
998 if (inputExprWalker.convolvedDims.count(outputDim) &&
999 !filterDims.count(outputDim)) {
1000 // Output image Loop dimension.
1001 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1003 allLoopDims.insert(outputDim);
1004 continue;
1005 }
1006 if (!inputExprWalker.convolvedDims.count(outputDim) &&
1007 !inputExprWalker.unConvolvedDims.count(outputDim) &&
1008 filterDims.count(outputDim)) {
1009 // Output channel dimension.
1010 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1012 allLoopDims.insert(outputDim);
1013 continue;
1014 }
1015 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
1016 filterDims.count(outputDim)) {
1017 // Depth multiplier.
1018 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1020 allLoopDims.insert(outputDim);
1021 continue;
1022 }
1024 }
1025 for (auto filterExpr : indexingMaps[1].getResults()) {
1026 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1027 if (outputDims.count(filterDim) &&
1028 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1029 !inputExprWalker.convolvedDims.count(filterDim)) {
1030 // Output channel dimension. This is already seen, continue;
1031 continue;
1032 }
1033 if (inputExprWalker.convolvedDims.count(filterDim) &&
1034 !outputDims.count(filterDim)) {
1035 // Filter loop dimension.
1036 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1038 if (allLoopDims.count(filterDim))
1040 allLoopDims.insert(filterDim);
1041 continue;
1042 }
1043 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1044 !outputDims.count(filterDim)) {
1045 // Input channel dimension.
1046 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1048 if (allLoopDims.count(filterDim))
1050 allLoopDims.insert(filterDim);
1051 continue;
1052 }
1053 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1054 outputDims.count(filterDim)) {
1055 // Depthwise loop. Already seen.
1056 continue;
1057 }
1059 }
1060 // All loops must be covered now.
1061 if (allLoopDims.size() != linalgOp.getNumLoops())
1063
1064 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1066
1067 if (dimensions) {
1068 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1069 indexingMaps, iteratorTypes, inputExprWalker, allowEmptyConvolvedDims,
1070 linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"),
1071 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations"));
1072 assert(succeeded(res) && "unexpected failure to infer convolution dims");
1073 *dimensions = *res;
1074 }
1075
1077}
1078
1079StringRef
1081 switch (res) {
1083 return "expected a LinalgOp";
1085 return "expected op with 2 inputs and 1 output";
1087 return "unexpected input index map for convolutions";
1089 return "expected output/filter indexing maps to be projected permutations";
1091 return "unexpected loop dimension for convolution op";
1093 return "expected all iterators used to access outputs to be parallel";
1095 return "expected all iterators not used to access outputs to be reduction";
1097 return "expected convolved dim to be non-empty";
1099 return "";
1100 }
1101 llvm_unreachable("unhandled MatchConvolutionResult case");
1102}
1103
1105 bool allowEmptyConvolvedDims) {
1107 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1109}
1110
1117
1118//===----------------------------------------------------------------------===//
1119// FillOpInterface implementation
1120//===----------------------------------------------------------------------===//
1121
namespace {
/// Outcome of matching an op against the FillOpInterface requirements.
/// `Success` is zero; every other enumerator names the first requirement that
/// the op failed.
enum class MatchFillResult {
  Success = 0,
  /// The op does not implement LinalgOp.
  NotLinalgOp = 1,
  /// The op does not have exactly one DPS input and one DPS init.
  WrongNumOperands = 2,
  /// The single DPS input is not a scalar.
  NotScalarInput = 3,
  /// The scalar input type differs from the element type of the init.
  TypeMismatch = 4,
};
} // namespace
1131
1132static MatchFillResult isFillInterfaceImpl(Operation *op) {
1133 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1134 if (!linalgOp)
1135 return MatchFillResult::NotLinalgOp;
1136 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1137 return MatchFillResult::WrongNumOperands;
1138
1139 OpOperand *value = linalgOp.getDpsInputOperand(0);
1140 if (!linalgOp.isScalar(value))
1141 return MatchFillResult::NotScalarInput;
1142
1143 // Check that the scalar input type matches the output element type.
1144 OpOperand *output = linalgOp.getDpsInitOperand(0);
1145 Type scalarType = value->get().getType();
1146 Type outputElementType = getElementTypeOrSelf(output->get().getType());
1147 if (scalarType != outputElementType)
1148 return MatchFillResult::TypeMismatch;
1149
1150 return MatchFillResult::Success;
1151}
1152
// Translate the matcher's verdict into a diagnostic for the first failed
// requirement.
1154 MatchFillResult res = isFillInterfaceImpl(op);
1155 if (res == MatchFillResult::NotLinalgOp)
1156 return op->emitError("expected a LinalgOp");
1157 if (res == MatchFillResult::WrongNumOperands)
1158 return op->emitError("expected op with 1 input and 1 output");
1159 if (res == MatchFillResult::NotScalarInput)
1160 return op->emitError("expected op with scalar input");
// Recompute both types so the diagnostic can name them.
// NOTE(review): this branch uses emitOpError (op-name prefixed) while the
// branches above use emitError — confirm the inconsistency is intended.
1161 if (res == MatchFillResult::TypeMismatch) {
1162 auto linalgOp = cast<linalg::LinalgOp>(op);
1163 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1164 Type outputElementType =
1165 getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
1166 return op->emitOpError("expected fill value type (")
1167 << scalarType << ") to match output element type ("
1168 << outputElementType << ")";
1169 }
1170
1171 return success();
1172}
1173
1174//===----------------------------------------------------------------------===//
1175// StructuredOpInterface implementation
1176//===----------------------------------------------------------------------===//
1177
/// Returns the size of every dimension of every operand of this op, flattened
/// in operand order (all dims of operand 0 first, then operand 1, ...). Each
/// size is produced by createFoldedDimOp.
1178SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1179 Location loc) {
1181 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1182 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1183 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1184 }
1185 return res;
1186}
1187
/// Static-shape counterpart of createFlatListOfOperandDims: concatenates the
/// shapes of all operands in operand order. Asserts (in debug builds) that the
/// op has no dynamic shape.
1188SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1190 assert(!hasDynamicShape() && "expected operands to have static shapes");
1191 for (OpOperand &opOperand : getOperation()->getOpOperands())
1192 llvm::append_range(res, getShape(&opOperand));
1193 return res;
1194}
1195
1196SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1197 AffineMap map = getLoopsToShapesMap();
1198 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1199 auto viewSizes = createFlatListOfOperandDims(b, loc);
1200 SmallVector<Range, 4> res(numDims);
1201 for (unsigned idx = 0; idx < numRes; ++idx) {
1202 auto result = map.getResult(idx);
1203 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1204 if (res[d.getPosition()].offset)
1205 continue;
1206 res[d.getPosition()] =
1207 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1208 }
1209 }
1210 return res;
1211}
1212
1213/// Visitor that returns true if an AffineExpr references any AffineDimExpr
1214/// whose position is set in `positions`.
/// Binary expressions match if either operand matches; constant and symbol
/// expressions never match.
1216 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1217 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1218 : positions(std::move(positions)) {}
1219
1221 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1222 }
1223
1225 return positions.test(dimExpr.getPosition());
1226 }
1227
1228 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1229
1230 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1231
1232private:
// Bit i set means dim position i is of interest. NOTE(review): presumably
// sized to cover every dim position of the visited expressions — confirm.
1233 llvm::SmallBitVector positions;
1234};
1235
/// Returns the half-open range [first, second) of result positions in the
/// op's loops-to-shapes map that hold the dimensions of the DPS init operands.
/// `first` is the summed rank of all DPS inputs; `second` adds the summed rank
/// of all inits.
/// NOTE(review): this assumes the flattened shape list places every DPS input
/// before every init — confirm for ops whose operand order differs.
1236static std::pair<int64_t, int64_t>
1238 int64_t inputRankSum = 0;
1239 int64_t outputRankSum = 0;
1240 for (OpOperand *input : op.getDpsInputOperands())
1241 inputRankSum += op.getRank(input);
1242 for (OpOperand &output : op.getDpsInitsMutable())
1243 outputRankSum += op.getRank(&output);
1244 return {inputRankSum, inputRankSum + outputRankSum};
1245}
1246
1247LogicalResult
1248LinalgOp::reifyResultShapes(OpBuilder &b,
1249 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1250 // An example that helps understand the logic below.
1251 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1252 // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1253 // This is achieved as follows.
1254 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1255 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1256 // shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
1257 // resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
1258 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
// i.e. dim 0 of O is dim 0 of A plus dim 1 of B. (k could equally be read
// from d2 instead of d1; it does not affect the result shapes here.)
1259 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1260
1261 // Find the range of positions in the above map that represent the shapes
1262 // of the results (i.e. of the DPS init operands).
1263 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1264
1265 // From loopsToShapesMap extract the submap that represents the shapes of
1266 // all results.
1267 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1268 resultShapesSubMapPos.first,
1269 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1270 AffineMap resultShapesFromInputShapesMap =
1271 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1272
1273 // Mark the positions that correspond to the outputs (inits). A dynamic
1274 // result dim whose expression references any of them cannot be computed
// from the input shapes alone; below it falls back to a dim op on the init.
1275 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1276 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1277 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1278 Location loc = getOperation()->getLoc();
1279 IRRewriter rewriter(b);
1280 SmallVector<OpFoldResult> allResultDimValues =
1282 rewriter, loc, resultShapesFromInputShapesMap,
1283 createFlatListOfOperandDims(b, loc));
// `pos` walks the result dims of all inits in order; it indexes both
// `shapeExprs` and `allResultDimValues`.
1284 int64_t pos = 0;
1285 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1286 for (OpOperand &opOperand : getDpsInitsMutable()) {
1287 SmallVector<OpFoldResult> shapes;
1288 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1289 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1290 if (!shapedType.isDynamicDim(dim)) {
1291 // Static dim: Return IntegerAttr.
1292 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1293 } else {
1294 // Dynamic dim: Return Value.
1295 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1296 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1297 : allResultDimValues[pos];
1298 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1299 }
1300 pos++;
1301 }
1302 reifiedReturnShapes.emplace_back(std::move(shapes));
1303 }
1304 return success();
1305}
1306
1307/// Return the index in the indexingMaps vector that corresponds to this
1308/// `opOperand`.
1309int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1310 auto operandNumber = opOperand->getOperandNumber();
1311 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1312 if (!dpsIface.isDpsInput(opOperand))
1313 return operandNumber;
1314 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1315 assert(!dpsIface.isDpsInit(opOperand));
1316 // Account for potential inputs that are not DPS and may not appear in
1317 // `indexingMaps`.
1318 return cast<DestinationStyleOpInterface>(*this->getOperation())
1319 .getNumDpsInputs() +
1320 operandNumber - start;
1321}
1322
1324 LinalgOp linalgOp = cast<LinalgOp>(op);
1325 // Mixed tensor/buffer operands are not allowed.
1326 if (!linalgOp.hasPureTensorSemantics() &&
1327 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1328 return op->emitOpError("expected to have pure tensor or buffer semantics");
1329
1330 // Before checking indexing maps, we need to make sure the attributes
1331 // referenced by it are valid.
1332 if (linalgOp.hasDynamicIndexingMaps())
1333 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1334 return failure();
1335
1336 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1337 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1338 return failure();
1339
1340 // Every operand's indexing map must be defined over the full iteration
1341 // domain, i.e. have exactly one dim per loop of the op.
1342 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1343 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1344 // Domain must be consistent.
1345 unsigned numLoops = linalgOp.getNumLoops();
1346 if (indexingMap.getNumDims() != numLoops)
1347 return op->emitOpError("expected indexing_map #")
1348 << opOperand.getOperandNumber() << " to have " << numLoops
1349 << " dim(s) to match the number of loops";
1350 }
// NOTE(review): `redDims` is filled but never read in this function; unless
// getReductionDims is relied on for a side effect, this looks like dead code.
1351 SmallVector<unsigned> redDims;
1352 linalgOp.getReductionDims(redDims);
1353
1354 if (!linalgOp.getShapesToLoopsMap())
1355 return op->emitOpError("expected the shape-to-loops map to be non-null");
1356
1357 // Check the region has exactly one block.
1358 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1359 return op->emitOpError("expects to have 1 region with 1 block");
1360
1361 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1362 // types.
1363 // TODO: once ranked shape types are plugged in, we may want to drop the
1364 // corresponding bbargs, that can never be read from. This will be subject to
1365 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1366 // not used).
1367 Block &block = linalgOp->getRegion(0).front();
1368
1369 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1370 return op->emitOpError("expected as many non-induction variable region "
1371 "arguments as the number of input/output operands");
1372
1373 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
// Shaped (memref / ranked tensor) operands are compared by element type;
// any other operand type is compared as-is.
1374 Type elementType = opOperand->get().getType();
1375 if (isa<MemRefType, RankedTensorType>(elementType))
1376 elementType = getElementTypeOrSelf(opOperand->get().getType());
// The block argument is looked up by the operand's number, i.e. bbarg i is
// paired with operand i.
1377 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1378 if (elementType != argType)
1379 return op->emitOpError("expected type of bb argument #")
1380 << opOperand->getOperandNumber() << " (" << argType << ")"
1381 << " to match element or self type of the corresponding operand ("
1382 << elementType << ")";
1383 }
1384
1385 return success();
1386}
return success()
lhs
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims, DenseIntElementsAttr nativeStrides, DenseIntElementsAttr nativeDilations)
Classifies dimensions in the indexingMaps used by a convolution subcomputation, as captured by inputE...
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define add(a, b)
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:782
unsigned getNumOperands()
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop