MLIR 23.0.0git
ElementwiseOpFusion.cpp
Go to the documentation of this file.
1//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Fusion on tensors operations pass.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>
#include <utility>
32
33namespace mlir {
34#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
35#include "mlir/Dialect/Linalg/Passes.h.inc"
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::linalg;
40
41//===---------------------------------------------------------------------===//
42// Methods and patterns that fuse elementwise `linalg.generic` operations.
43//===---------------------------------------------------------------------===//
44
45/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
46/// the `producer` to use in the fused operation given the indexing map of the
47/// result of the producer in the consumer.
49 OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
50 AffineMap fusedConsumerArgIndexMap) {
51 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
52 // from consumer loop -> consumer arg tensor index/producer result tensor
53 // index. The fused loop is same as the consumer loop. For each producer arg
54 // the indexing map to be computed is a map from consumer loop -> producer
55 // arg tensor index.
56 // producerResultIndexMap is a map from producer loop -> tensor index.
57 // Compute the inverse to get map from tensor index -> producer loop.
58 // The inverse is a map from producer result tensor index -> producer loop.
59 AffineMap invProducerResultIndexMap =
60 inversePermutation(producerResultIndexMap);
61 assert(invProducerResultIndexMap &&
62 "expected producer result indexing map to be invertible");
63
64 LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
65 // argMap is a map from producer loop -> producer arg tensor index.
66 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
67
68 // Compose argMap with invProducerResultIndexMap to get a map from
69 // producer result tensor index -> producer arg tensor index.
70 AffineMap t1 = argMap.compose(invProducerResultIndexMap);
71
72 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
73 // consumer loop/ fused loop -> producer arg tensor index.
74 return t1.compose(fusedConsumerArgIndexMap);
75}
76
77// Checks if the given operand can be dropped, and the remaining operands
78// of the fused producer & consumer after the fusion can still compute the
79// bounds of the op.
81 GenericOp producer, GenericOp consumer,
82 ArrayRef<OpOperand *> opOperandsToIgnore) {
83 SmallVector<AffineMap> indexingMaps;
84
85 SmallVector<GenericOp> ops = {producer, consumer};
86 for (auto &op : ops) {
87 for (auto &opOperand : op->getOpOperands()) {
88 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
89 continue;
90 }
91 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
92 }
93 }
94 if (indexingMaps.empty()) {
95 // If there are no indexing maps, the operand can only be dropped
96 // if neither op has loops.
97 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
98 }
99
100 // The concatanation of the remained indexing maps must be invertible, so
101 // the bounds of the op can be still computed after dropping the selected
102 // operand. inversePermutation returns an empty AffineMap in case the
103 // concatanated indexing maps are not invertible.
105 indexingMaps, producer.getContext())) != AffineMap();
106}
107
108/// Returns a set of indices of the producer's results which would
109/// be preserved after the fusion.
110/// * There is a chance that the implementation of the transformation does not
111/// agree with the result of this method. This function gives a prediction based
112/// on an optimized fusion.
114 GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
115 llvm::SmallDenseSet<int> preservedProducerResults;
116 llvm::SmallVector<OpOperand *> opOperandsToIgnore;
117
118 // The fusedOperand will be removed during the fusion
119 opOperandsToIgnore.emplace_back(fusedOperand);
120
121 for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
122 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
123 opOperandsToIgnore.emplace_back(outputOperand);
124 if (producer.payloadUsesValueFromOperand(outputOperand) ||
126 opOperandsToIgnore) ||
127 llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
128 return user != consumer.getOperation();
129 })) {
130 preservedProducerResults.insert(producerResult.index());
131
132 // In case the operand can't be dropped
133 (void)opOperandsToIgnore.pop_back_val();
134 }
135 }
136 return preservedProducerResults;
137}
138
139/// Conditions for elementwise fusion of generic operations.
141 if (!fusedOperand)
142 return false;
143
144 auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
145 auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
146
147 // Check producer and consumer are generic ops.
148 if (!producer || !consumer)
149 return false;
150
151 // Consumer can have mixed semantics, just check operand itself has tensor
152 // type. Producer must have full tensor semantics to avoid potential
153 // aliasing between producer and consumer memrefs.
154 if (!producer.hasPureTensorSemantics() ||
155 !isa<RankedTensorType>(fusedOperand->get().getType()))
156 return false;
157
158 // Verify that
159 // - the producer has all "parallel" iterator type.
160 if (producer.getNumParallelLoops() != producer.getNumLoops())
161 return false;
162
163 // Only allow fusing the producer of an input operand for now.
164 // TODO: allow fusing the producer of an output operand.
165 if (!consumer.isDpsInput(fusedOperand))
166 return false;
167
168 // Get the consumer index map. The number of results of the consumer index
169 // map must match the number of loops of the producer.
170 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
171 if (consumerIndexMap.getNumResults() != producer.getNumLoops())
172 return false;
173
174 // Finally the index_map for the result must be invertible. For now just
175 // verify it is a permutation.
176 auto producerResult = cast<OpResult>(fusedOperand->get());
177 AffineMap producerResultIndexMap =
178 producer.getIndexingMapMatchingResult(producerResult);
179 if (!producerResultIndexMap.isPermutation())
180 return false;
181
182 // Ensure that the fusion does not remove size information required to
183 // get the loop bounds. For non-reduction generics, this is trivially the
184 // case due to the output operand. For reductions, we need to check that after
185 // the fusion, each loop dimension has at least one input that defines it.
186 if ((consumer.getNumReductionLoops())) {
187 BitVector coveredDims(consumer.getNumLoops(), false);
188
189 auto addToCoveredDims = [&](AffineMap map) {
190 for (auto result : map.getResults())
191 if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
192 coveredDims[dimExpr.getPosition()] = true;
193 };
194
195 for (auto pair :
196 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
197 Value operand = std::get<0>(pair);
198 if (operand == fusedOperand->get())
199 continue;
200 AffineMap operandMap = std::get<1>(pair);
201 addToCoveredDims(operandMap);
202 }
203
204 for (OpOperand *operand : producer.getDpsInputOperands()) {
205 AffineMap newIndexingMap =
207 operand, producerResultIndexMap, consumerIndexMap);
208 addToCoveredDims(newIndexingMap);
209 }
210 if (!coveredDims.all())
211 return false;
212 }
213
214 return true;
215}
216
217/// Generate the region of the fused tensor operation. The region of the fused
218/// op must be empty.
220 RewriterBase &rewriter, GenericOp fusedOp,
221 AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
222 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
223 auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
224 auto consumer = cast<GenericOp>(fusedOperand->getOwner());
225 // Build the region of the fused op.
226 Block &producerBlock = producer->getRegion(0).front();
227 Block &consumerBlock = consumer->getRegion(0).front();
228 OpBuilder::InsertionGuard guard(rewriter);
229 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
230 IRMapping mapper;
231
232 // 2. Add an index operation for every fused loop dimension and use the
233 // `consumerToProducerLoopsMap` to map the producer indices.
234 if (producer.hasIndexSemantics()) {
235 // Add an index operation for every fused loop dimension.
236 unsigned numFusedOpLoops = fusedOp.getNumLoops();
237 SmallVector<Value> fusedIndices;
238 fusedIndices.reserve(numFusedOpLoops);
239 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
240 std::back_inserter(fusedIndices), [&](uint64_t dim) {
241 return IndexOp::create(rewriter, producer.getLoc(), dim);
242 });
243 for (IndexOp indexOp :
244 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
245 Value newIndex = affine::AffineApplyOp::create(
246 rewriter, producer.getLoc(),
247 consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
248 mapper.map(indexOp.getResult(), newIndex);
249 }
250 }
251 // TODO: allow fusing the producer of an output operand.
252 assert(consumer.isDpsInput(fusedOperand) &&
253 "expected producer of input operand");
254 // 3. Consumer input operands up to consumerIdx (exclusive).
255 for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
256 fusedOperand->getOperandNumber())) // input assumption.
257 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
258
259 // Replacing consumerIdx requires getting the cloned, yielded, value from
260 // the (cloned) producer block. This happens in step 9.
261
262 // 4. Splice in producer's input operands.
263 for (BlockArgument bbArg :
264 producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
265 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
266
267 // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
268 for (BlockArgument bbArg :
269 consumerBlock.getArguments()
270 .take_front(consumer.getNumDpsInputs())
271 .drop_front(fusedOperand->getOperandNumber() + 1))
272 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
273
274 // 6. All of the producer's output operands
275 for (const auto &bbArg : llvm::enumerate(
276 producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
277 if (!preservedProducerResults.count(bbArg.index()))
278 continue;
279 mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
280 bbArg.value().getLoc()));
281 }
282
283 // 7. All of consumer's output operands.
284 for (BlockArgument bbArg :
285 consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
286 mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
287
288 // 8. Clone all producer operations except for the yield and index operations
289 // to the fused operation.
290 for (auto &op : producerBlock.without_terminator()) {
291 if (!isa<IndexOp>(op))
292 rewriter.clone(op, mapper);
293 }
294 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
295 // forward the yield operand.
296 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
297 unsigned producerResultNumber =
298 cast<OpResult>(fusedOperand->get()).getResultNumber();
300 mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
301
302 // Sanity checks, if replacement is not already in the mapper then it must be
303 // produced outside.
304 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
305 if (auto bb = dyn_cast<BlockArgument>(replacement))
306 assert(bb.getOwner() != &producerBlock &&
307 "yielded block argument must have been mapped");
308 else
309 assert(!producer->isAncestor(replacement.getDefiningOp()) &&
310 "yielded value must have been mapped");
311 }
312 mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
314 // 10. Clone operations from the consumer to the fused op.
315 for (auto &op : consumerBlock.without_terminator())
316 rewriter.clone(op, mapper);
317
318 // 11. Include the final yield (which is the remapped values for all the
319 // yield)
320 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
321 SmallVector<Value> fusedYieldValues;
322 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
323 consumerYieldOp.getNumOperands());
324 for (const auto &producerYieldVal :
325 llvm::enumerate(producerYieldOp.getOperands())) {
326 if (preservedProducerResults.count(producerYieldVal.index()))
327 fusedYieldValues.push_back(
328 mapper.lookupOrDefault(producerYieldVal.value()));
329 }
330 for (auto consumerYieldVal : consumerYieldOp.getOperands())
331 fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
332 YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
333
334 // Sanity checks.
335 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
336 "Ill-formed GenericOp region");
337}
338
339FailureOr<mlir::linalg::ElementwiseOpFusionResult>
341 OpOperand *fusedOperand) {
342 assert(areElementwiseOpsFusable(fusedOperand) &&
343 "expected elementwise operation pre-conditions to pass");
344 auto producerResult = cast<OpResult>(fusedOperand->get());
345 auto producer = cast<GenericOp>(producerResult.getOwner());
346 auto consumer = cast<GenericOp>(fusedOperand->getOwner());
347 // TODO: allow fusing the producer of an output operand.
348 assert(consumer.isDpsInput(fusedOperand) &&
349 "expected producer of input operand");
350 /// Find the results of the producer that have uses outside of the consumer,
351 /// after the fusion.
352 llvm::SmallDenseSet<int> preservedProducerResults =
354 fusedOperand);
355
356 // Compute the fused operands list and indexing maps.
357 SmallVector<Value> fusedInputOperands, fusedOutputOperands;
358 SmallVector<Type> fusedResultTypes;
359 SmallVector<AffineMap> fusedIndexMaps;
360 fusedInputOperands.reserve(producer.getNumDpsInputs() +
361 consumer.getNumDpsInputs());
362 fusedOutputOperands.reserve(preservedProducerResults.size() +
363 consumer.getNumDpsInits());
364 fusedResultTypes.reserve(preservedProducerResults.size() +
365 consumer.getNumDpsInits());
366 fusedIndexMaps.reserve(producer->getNumOperands() +
367 consumer->getNumOperands());
368 // In the following, numbering matches that of `generateFusedTensorOpRegion`.
369 // 3. Consumer input operands/maps up to consumerIdx (exclusive).
370 auto consumerInputs = consumer.getDpsInputOperands();
371 auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
372 return operand == fusedOperand;
373 });
374 assert(it != consumerInputs.end() && "expected to find the consumer operand");
375 for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
376 fusedInputOperands.push_back(opOperand->get());
377 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
378 }
379 // 4. Splice in producer's input operands/maps.
380 AffineMap producerResultIndexMap =
381 producer.getIndexingMapMatchingResult(producerResult);
382 for (OpOperand *opOperand : producer.getDpsInputOperands()) {
383 fusedInputOperands.push_back(opOperand->get());
384 // Compute indexing maps for the producer args in the fused operation.
386 opOperand, producerResultIndexMap,
387 consumer.getMatchingIndexingMap(fusedOperand));
388 fusedIndexMaps.push_back(map);
389 }
390 // 5. Remaining consumer's input operands/maps (drop past index
391 // `consumerIdx`).
392 for (OpOperand *opOperand :
393 llvm::make_range(std::next(it), consumerInputs.end())) {
394 fusedInputOperands.push_back(opOperand->get());
395 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
396 }
397
398 // 6. Collect all of the producer outputs.
399 for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
400 if (!preservedProducerResults.count(opOperand.index()))
401 continue;
402
403 fusedOutputOperands.push_back(opOperand.value().get());
405 &opOperand.value(), producerResultIndexMap,
406 consumer.getMatchingIndexingMap(fusedOperand));
407 fusedIndexMaps.push_back(map);
408 fusedResultTypes.push_back(opOperand.value().get().getType());
409 }
410
411 // 7. All of consumer's output operands (skip operands: added by the builder).
412 for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
413 fusedOutputOperands.push_back(opOperand.get());
414 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
415 Type resultType = opOperand.get().getType();
416 if (!isa<MemRefType>(resultType))
417 fusedResultTypes.push_back(resultType);
418 }
419
420 // Generate the fused op.
421 auto fusedOp = GenericOp::create(
422 rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
423 fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
424 consumer.getIteratorTypes(),
425 /*doc=*/nullptr,
426 /*library_call=*/nullptr);
427 if (!fusedOp.getShapesToLoopsMap()) {
428 // Fused op has invalid indexing maps. Typically this means something is off
429 // in the input, but going ahead here would result in verification errors.
430 // So cleanup and abort.
431 rewriter.eraseOp(fusedOp);
432 return rewriter.notifyMatchFailure(
433 fusedOp, "fused op failed loop bound computation check");
434 }
435
436 // Construct an AffineMap from consumer loops to producer loops.
437 // consumer loop -> tensor index
438 AffineMap consumerResultIndexMap =
439 consumer.getMatchingIndexingMap(fusedOperand);
440 // tensor index -> producer loop
441 AffineMap invProducerResultIndexMap =
442 inversePermutation(producerResultIndexMap);
443 assert(invProducerResultIndexMap &&
444 "expected producer result indexig map to be invertible");
445 // consumer loop -> producer loop
446 AffineMap consumerToProducerLoopsMap =
447 invProducerResultIndexMap.compose(consumerResultIndexMap);
448
450 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
451 consumer.getNumLoops(), preservedProducerResults);
453 result.fusedOp = fusedOp;
454 int resultNum = 0;
455 for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
456 if (preservedProducerResults.count(index))
457 result.replacements[producerResult] = fusedOp->getResult(resultNum++);
458 for (auto consumerResult : consumer->getResults())
459 result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
460 return result;
461}
462
463namespace {
464/// Patterns to fuse a generic op, with the producer of its operands.
465class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
466public:
467 FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
468 PatternBenefit benefit = 1)
469 : OpRewritePattern<GenericOp>(context, benefit),
470 controlFn(std::move(fun)) {}
471
472 LogicalResult matchAndRewrite(GenericOp genericOp,
473 PatternRewriter &rewriter) const override {
474 // Find the first operand that is defined by another generic op on tensors.
475 for (OpOperand &opOperand : genericOp->getOpOperands()) {
476 if (!areElementwiseOpsFusable(&opOperand))
477 continue;
478 if (!controlFn(&opOperand))
479 continue;
480
481 Operation *producer = opOperand.get().getDefiningOp();
482
483 // Find the producer of the operand.
484 FailureOr<ElementwiseOpFusionResult> fusionResult =
485 fuseElementwiseOps(rewriter, &opOperand);
486 if (failed(fusionResult))
487 return rewriter.notifyMatchFailure(genericOp, "fusion failed");
488
489 // Perform the fusion.
490 for (auto [origVal, replacement] : fusionResult->replacements) {
491 rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
492 // Only replace consumer uses.
493 return use.get().getDefiningOp() != producer;
494 });
495 }
496 rewriter.eraseOp(genericOp);
497 return success();
498 }
499 return failure();
500 }
501
502private:
503 ControlFusionFn controlFn;
504};
505} // namespace
506
507//===---------------------------------------------------------------------===//
508// Methods and patterns that fuse reshape ops with elementwise operations by
509// expanding the dimensionality of the elementwise operations.
510//===---------------------------------------------------------------------===//
511
512/// Conditions for folding a structured linalg operation with a reshape op by
513/// expanding the iteration space dimensionality for tensor operations. These
514/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
515/// the following fusion pattern.
516///
517/// Consider
518///
519/// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
520/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
521/// affine_map<(d0, d1, d2) -> (d1, d2)>,
522/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
523/// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
524/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
525///
526/// The reshape can be folded into the `linalgOp` if its loop dimensionality
527/// is increased to match the result (operand) of the tensor.expand_shape.
528/// The indexing_map of the fused tensor in the `linalgOp` and the
529/// reassociation map helps compute the indexing maps of the modified op.
530/// For the above example, based on the reassociation map it
531/// can be concluded that
532///
533/// - The loop used to access the first dimension of the fused tensor is split
534/// into two.
535/// - The loop used to access the second dimension of the fused tensor is kept
536/// as is.
537/// - The loop used to access the third dimension of the fused tensor is split
538/// into three.
539///
540/// i.e. (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
541/// modified op, then
542///
543/// d0 -> e0, e1
544/// d1 -> e2, e3, e4
545/// d2 -> e5
546///
547/// substituting this, the structured op can be rewritten as
548///
549/// %d = linalg.generic ins(%0, %1 : )
550/// indexing_maps =
551/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
552/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
553/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
554///
555/// Since operands to the linalg generic are now 5D, reshapes can be introduced
556/// to make it consistent
557///
558/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
559/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
560/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
561/// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
562///
563/// The added reshapes are again expanding patterns, so they will get fused
564/// with its producers if possible.
565static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
566 OpOperand *fusableOpOperand) {
567 // Is fusable only if:
568 // - All the indexing maps for operands and results are projected
569 // permutations.
570 // - The fused tensor is not a scalar.
572 linalgOp.getIteratorTypesArray();
573 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
574 return linalgOp.hasPureTensorSemantics() &&
575 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
576 [](Attribute attr) {
577 return cast<AffineMapAttr>(attr)
578 .getValue()
579 .isProjectedPermutation();
580 }) &&
581 operandMap.getNumResults() > 0;
582}
583
584namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  /// Computes the mapping from the original dimensions of the op to the
  /// dimensions of the expanded op, given the `indexingMap` of the fused
  /// operand/result of the generic op, the `reassociationMaps` of the reshape
  /// op, and the shape of the expanded op. Fails when `reassociationMaps` is
  /// empty.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  /// Number of loop dimensions in the original (unexpanded) operation.
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  /// Number of loop dimensions in the expanded operation.
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  /// Indices of the expanded-op dimensions that original dimension `i`
  /// expands into.
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  /// Extents of the expanded dimensions for original dimension `i`.
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  /// Loop extents of the original operation.
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  // Total number of dimensions of the expanded op; set by `compute`.
  unsigned expandedOpNumDims;
};
618} // namespace
619
620LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
621 OpOperand *fusableOpOperand,
622 ArrayRef<AffineMap> reassociationMaps,
623 ArrayRef<OpFoldResult> expandedShape,
624 PatternRewriter &rewriter) {
625 if (reassociationMaps.empty())
626 return failure();
627 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
628
629 OpBuilder::InsertionGuard g(rewriter);
630 rewriter.setInsertionPoint(linalgOp);
631 originalLoopExtent = llvm::map_to_vector(
632 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
633 [](Range r) { return r.size; });
634
635 reassociation.clear();
636 expandedShapeMap.clear();
637 // Compute the number of dimension in the expanded op that correspond to each
638 // dimension of the original op.
639 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
640 expandedShapeMap.resize(fusedIndexMap.getNumDims());
641 for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
642 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
643 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
644 numExpandedDims[pos] = foldedDims.getNumResults();
645 ArrayRef<OpFoldResult> shape =
646 expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
647 expandedShapeMap[pos].assign(shape.begin(), shape.end());
648 }
649 // The remaining dimensions remain the same.
650 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
651 if (expandedShapeMap[i].empty())
652 expandedShapeMap[i] = {originalLoopExtent[i]};
653
654 // Compute reassociation map from the original op to the expanded op.
655 unsigned sum = 0;
656 reassociation.reserve(fusedIndexMap.getNumDims());
657 for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
658 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
659 reassociation.emplace_back(seq.begin(), seq.end());
660 sum += numFoldedDim.value();
661 }
662 expandedOpNumDims = sum;
663 return success();
664}
665
666/// Return the indexing map to use in the expanded op for a given the
667/// `indexingMap` of the original operation.
668static AffineMap
670 const ExpansionInfo &expansionInfo) {
672 for (AffineExpr expr : indexingMap.getResults()) {
673 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
674 SmallVector<AffineExpr, 4> expandedExprs = llvm::map_to_vector<4>(
675 expansionInfo.getExpandedDims(pos), [&](int64_t v) {
676 return builder.getAffineDimExpr(static_cast<unsigned>(v));
677 });
678 newExprs.append(expandedExprs.begin(), expandedExprs.end());
679 }
680 return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
681 indexingMap.getNumSymbols(), newExprs,
682 builder.getContext());
683}
684
685/// Return the shape and type of the operand/result to use in the expanded op
686/// given the type in the original op.
687static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
688getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
689 const ExpansionInfo &expansionInfo) {
690 SmallVector<OpFoldResult> expandedShape;
691 for (AffineExpr expr : indexingMap.getResults()) {
692 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
693 ArrayRef<OpFoldResult> dimExpansion =
694 expansionInfo.getExpandedShapeOfDim(dim);
695 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
696 }
697 SmallVector<int64_t> expandedStaticShape;
698 std::tie(expandedStaticShape, std::ignore) =
699 decomposeMixedValues(expandedShape);
700 return {expandedShape, RankedTensorType::get(expandedStaticShape,
701 originalType.getElementType())};
702}
703
704/// Returns the reassociation maps to use in the `tensor.expand_shape`
705/// operation to convert the operands of the original operation to operands of
706/// the expanded operation. The same method is used to compute the
707/// `tensor.collapse_shape` used to collapse the result of the expanded
708/// op to get the value that can replace all uses of the results of the original
709/// op.
710static SmallVector<ReassociationIndices>
712 const ExpansionInfo &expansionInfo) {
714 unsigned numReshapeDims = 0;
715 for (AffineExpr expr : indexingMap.getResults()) {
716 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
717 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
718 SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
719 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
720 reassociation.emplace_back(std::move(indices));
721 numReshapeDims += numExpandedDims;
722 }
723 return reassociation;
724}
725
726/// Update the body of an expanded linalg operation having index semantics. The
727/// indices of the original operation need to be recovered by linearizing the
728/// indices of the correspoding dimensions of the expanded operation. For now it
729/// is assumed that the shapes of the expanded operation needed for
730/// linearization are static.
732 Location loc, Region &fusedRegion,
733 const ExpansionInfo &expansionInfo) {
734 // Replace the original indices by the linearization of the expanded indices.
735 for (IndexOp indexOp :
736 llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
737 ArrayRef<int64_t> expandedDims =
738 expansionInfo.getExpandedDims(indexOp.getDim());
739 assert(!expandedDims.empty() && "expected valid expansion info");
740
741 // Skip index operations that are not affected by the expansion.
742 if (expandedDims.size() == 1 &&
743 expandedDims.front() == (int64_t)indexOp.getDim())
744 continue;
745
746 // Linearize the expanded indices of the original index dimension.
747 OpBuilder::InsertionGuard guard(rewriter);
748 rewriter.setInsertionPointAfter(indexOp);
749 ArrayRef<OpFoldResult> expandedDimsShape =
750 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
751 SmallVector<Value> expandedIndices;
752 expandedIndices.reserve(expandedDims.size() - 1);
753 llvm::transform(
754 expandedDims.drop_front(), std::back_inserter(expandedIndices),
755 [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
756 OpFoldResult newIndex =
757 IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
758 for (auto [expandedShape, expandedIndex] :
759 llvm::zip(expandedDimsShape, expandedIndices)) {
760 AffineExpr idx, acc, shape;
761 bindDims(rewriter.getContext(), idx, acc);
762 bindSymbols(rewriter.getContext(), shape);
764 rewriter, indexOp.getLoc(), idx + acc * shape,
765 ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
766 }
767 Value newIndexVal =
768 getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
769 rewriter.replaceOp(indexOp, newIndexVal);
770 }
771}
772
773// Create an expanded transpose op.
774// the reassociation map is already permuted hence we inverse permute and then
775// flatten it. Then we inverse permute it again to get the final expanded
776// transpose permutation. For example,
777//
778// permutation = [2, 0, 1]
779// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
780//
781// inverse permutation = [1, 2, 0]
782// applied to reassocation_map and then flattened becomes
783// flatened permutation = [2, 3, 4, 5, 0, 1]
784// final permuation is the inverse of the flattened permutation.
785//
786// Becomes
787//
788// permutation=[4, 5, 0, 1, 2, 3]
789
791 TransposeOp transposeOp,
792 Value expandedInput, Value output,
793 ExpansionInfo &expansionInfo) {
794 SmallVector<int64_t> newPerm;
795 for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
796 auto reassoc = expansionInfo.getExpandedDims(perm);
797 for (int64_t dim : reassoc) {
798 newPerm.push_back(dim);
799 }
800 }
801 return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
802 output, invertPermutationVector(newPerm));
803}
804
805// Create an expanded generic op.
807 PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
808 ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
809 ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
810 // The iterator types of the expanded op are all parallel.
812 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
813
814 for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
815 for (auto j : expansionInfo.getExpandedDims(i))
816 iteratorTypes[j] = type;
817
818 Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
819 expandedOpOperands, outputs,
820 expandedOpIndexingMaps, iteratorTypes);
821
822 Region &fusedRegion = fused->getRegion(0);
823 Region &originalRegion = linalgOp->getRegion(0);
824 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
825
826 // Update the index accesses after the expansion.
827 updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
828 expansionInfo);
829
830 return fused;
831}
832
833// Create an expanded fused op that retains the name for certain ops
834// such as fill, copy and transpose and produce a generic op for
835// rest of linalg ops.
836static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
837 TypeRange resultTypes,
838 ArrayRef<Value> expandedOpOperands,
839 ArrayRef<Value> outputs,
840 ArrayRef<AffineMap> expandedOpIndexingMaps,
841 ExpansionInfo &expansionInfo) {
842
843 return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
844 .Case([&](TransposeOp transposeOp) {
845 return createExpandedTransposeOp(rewriter, transposeOp,
846 expandedOpOperands[0], outputs[0],
847 expansionInfo);
848 })
849 .Case<FillOp, CopyOp>([&](Operation *op) {
850 return clone(rewriter, linalgOp, resultTypes,
851 llvm::to_vector(llvm::concat<Value>(
852 llvm::to_vector(expandedOpOperands),
853 llvm::to_vector(outputs))));
854 })
855 .Default([&](Operation *op) {
856 return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
857 expandedOpOperands, outputs,
858 expansionInfo, expandedOpIndexingMaps);
859 });
860}
861
862/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
863/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
864/// that those conditions have been satisfied.
865static std::optional<SmallVector<Value>>
866fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
867 OpOperand *fusableOpOperand,
868 PatternRewriter &rewriter) {
869 assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
870 "preconditions for fuse operation failed");
871
872 Location loc = linalgOp.getLoc();
873 SmallVector<OpFoldResult> expandedShape;
874 SmallVector<AffineMap, 4> reassociationIndices;
875 Value src;
876 if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
877 // Try to move the dynamic dimensions in output shape before the `linalgOp`
878 // to maintain SSA validity
879 if (failed(moveValueDefinitions(
880 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
881 return std::nullopt;
882
883 expandedShape = expandingReshapeOp.getMixedOutputShape();
884 reassociationIndices = expandingReshapeOp.getReassociationMaps();
885 src = expandingReshapeOp.getSrc();
886 } else {
887 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
888 if (!collapsingReshapeOp)
889 return std::nullopt;
890
891 expandedShape = tensor::getMixedSizes(
892 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
893 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
894 src = collapsingReshapeOp.getSrc();
895 }
896
897 ExpansionInfo expansionInfo;
898 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
899 reassociationIndices, expandedShape,
900 rewriter)))
901 return std::nullopt;
902
903 SmallVector<AffineMap, 4> expandedOpIndexingMaps =
904 llvm::map_to_vector<4>(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
905 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
906 });
907
908 // Set insertion point to the generic op.
909 OpBuilder::InsertionGuard g(rewriter);
910 rewriter.setInsertionPoint(linalgOp);
911
912 SmallVector<Value> expandedOpOperands;
913 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
914 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
915 if (opOperand == fusableOpOperand) {
916 expandedOpOperands.push_back(src);
917 continue;
918 }
919 if (auto opOperandType =
920 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
921 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
922 SmallVector<OpFoldResult> expandedOperandShape;
923 RankedTensorType expandedOperandType;
924 std::tie(expandedOperandShape, expandedOperandType) =
925 getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
926 if (expandedOperandType != opOperand->get().getType()) {
927 // Reshape the operand to get the right type.
929 getReassociationForExpansion(indexingMap, expansionInfo);
930 if (failed(reshapeLikeShapesAreCompatible(
931 [&](const Twine &msg) {
932 return rewriter.notifyMatchFailure(linalgOp, msg);
933 },
934 opOperandType.getShape(), expandedOperandType.getShape(),
935 reassociation,
936 /*isExpandingReshape=*/true)))
937 return std::nullopt;
938 expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
939 rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
940 expandedOperandShape));
941 continue;
942 }
943 }
944 expandedOpOperands.push_back(opOperand->get());
945 }
946
947 SmallVector<Value> outputs;
948 for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
949 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
950 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
951 SmallVector<OpFoldResult> expandedOutputShape;
952 RankedTensorType expandedOutputType;
953 std::tie(expandedOutputShape, expandedOutputType) =
954 getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
955 if (expandedOutputType != opOperand.get().getType()) {
957 getReassociationForExpansion(indexingMap, expansionInfo);
958 if (failed(reshapeLikeShapesAreCompatible(
959 [&](const Twine &msg) {
960 return rewriter.notifyMatchFailure(linalgOp, msg);
961 },
962 opOperandType.getShape(), expandedOutputType.getShape(),
963 reassociation,
964 /*isExpandingReshape=*/true)))
965 return std::nullopt;
966 outputs.push_back(tensor::ExpandShapeOp::create(
967 rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
968 expandedOutputShape));
969 } else {
970 outputs.push_back(opOperand.get());
971 }
972 }
973
974 TypeRange resultTypes = ValueRange(outputs).getTypes();
975 Operation *fusedOp =
976 createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
977 outputs, expandedOpIndexingMaps, expansionInfo);
978 // Reshape the result values to their original shape if this is a collapsing
979 // reshape folded into its consumer.
980 SmallVector<Value> resultVals;
981 for (OpResult opResult : linalgOp->getOpResults()) {
982 int64_t resultNumber = opResult.getResultNumber();
983 if (resultTypes[resultNumber] != opResult.getType()) {
986 linalgOp.getMatchingIndexingMap(
987 linalgOp.getDpsInitOperand(resultNumber)),
988 expansionInfo);
989 resultVals.push_back(tensor::CollapseShapeOp::create(
990 rewriter, linalgOp.getLoc(), opResult.getType(),
991 fusedOp->getResult(resultNumber), reassociation));
992 } else {
993 resultVals.push_back(fusedOp->getResult(resultNumber));
994 }
995 }
996 // Assuming a single result.
997 return resultVals;
998}
999
1000namespace {
1001
1002/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
1003/// when the reshape op is collapsing dimensions. The dimensionality of the loop
1004/// in the consumer is expanded.
1005class FoldWithProducerReshapeOpByExpansion
1006 : public OpInterfaceRewritePattern<LinalgOp> {
1007public:
1008 FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1009 ControlFusionFn foldReshapes,
1010 PatternBenefit benefit = 1)
1011 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
1012 controlFoldingReshapes(std::move(foldReshapes)) {}
1013
1014 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 PatternRewriter &rewriter) const override {
1016 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1017 tensor::CollapseShapeOp reshapeOp =
1018 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1019 if (!reshapeOp)
1020 continue;
1021 // Fold only if
1022 // - The tensor reshape op is folding.
1023 // - All constraints of fusing with reshape by expansion are met.
1024 if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
1025 (!controlFoldingReshapes(opOperand)))
1026 continue;
1027
1028 std::optional<SmallVector<Value>> replacementValues =
1029 fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
1030 if (!replacementValues)
1031 return failure();
1032 rewriter.replaceOp(linalgOp, *replacementValues);
1033 return success();
1034 }
1035 return failure();
1036 }
1037
1038private:
1039 ControlFusionFn controlFoldingReshapes;
1040};
1041
/// Carries information about a padded dimension. All vectors are indexed by
/// dimension of the expanded (higher-rank) domain; see
/// `computeExpandedPadding`, which populates this struct.
struct PadDimInfo {
  // The resulting shape after padding each dimension.
  SmallVector<int64_t> paddedShape;

  // Low and high padding amounts for each dimension.
  SmallVector<OpFoldResult> lowPad;
  SmallVector<OpFoldResult> highPad;
};
1051
1052/// Computes the expanded padding information for the given pad operation based
1053/// on the provided expanded shape and reassociation indices. Returns a list of
1054/// PadDimInfo containing the low and high padding amounts and the padded
1055/// size for each dimension, or failure if the expansion is not possible.
1056static FailureOr<PadDimInfo>
1057computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
1058 ArrayRef<ReassociationIndices> reassociations,
1059 PatternRewriter &rewriter) {
1060 // If the padding value depends on the index values of the pad operation,
1061 // then it may not be valid to expand the dimensions, since it will change
1062 // the index values on which the padding value depends. This is not currently
1063 // supported by the pad expansion patterns, but it could be implemented
1064 // similarly to the expansion of linalg.generic ops with linalg.index ops in
1065 // the body, as is done in `updateExpandedGenericOpRegion`.
1066 if (!padOp.getConstantPaddingValue())
1067 return failure();
1068
1069 // Expanded dimensions cannot have padding because the resulting padding may
1070 // not be representable by a tensor.pad op. There are some special cases where
1071 // it is possible (like expanding unit dims), but supporting these cases is
1072 // NYI, so disallow it for now.
1073 ArrayRef<int64_t> low = padOp.getStaticLow();
1074 ArrayRef<int64_t> high = padOp.getStaticHigh();
1075 for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1076 if (reInd.size() != 1 && (l != 0 || h != 0))
1077 return failure();
1078 }
1079
1080 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
1081 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
1082 ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
1083 PadDimInfo padDimInfo;
1084 padDimInfo.paddedShape.assign(expandedShape);
1085 padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
1086 padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
1087 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1088 if (reInd.size() == 1) {
1089 padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
1090 padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
1091 padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
1092 }
1093 }
1094
1095 return padDimInfo;
1096}
1097
1098class FoldPadWithProducerReshapeOpByExpansion
1099 : public OpRewritePattern<tensor::PadOp> {
1100public:
1101 FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1102 ControlFusionFn foldReshapes,
1103 PatternBenefit benefit = 1)
1104 : OpRewritePattern<tensor::PadOp>(context, benefit),
1105 controlFoldingReshapes(std::move(foldReshapes)) {}
1106
1107 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1108 PatternRewriter &rewriter) const override {
1109 tensor::CollapseShapeOp reshapeOp =
1110 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1111 if (!reshapeOp)
1112 return failure();
1113
1114 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1115 return rewriter.notifyMatchFailure(padOp,
1116 "fusion blocked by control function");
1117 }
1118
1119 RankedTensorType expandedType = reshapeOp.getSrcType();
1120 SmallVector<ReassociationIndices> reassociations =
1121 reshapeOp.getReassociationIndices();
1122 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1123 padOp, expandedType.getShape(), reassociations, rewriter);
1124 if (failed(maybeExpandedPadding))
1125 return failure();
1126 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1127
1128 Location loc = padOp->getLoc();
1129 RankedTensorType expandedPaddedType =
1130 padOp.getResultType().clone(expandedPadding.paddedShape);
1131
1132 auto newPadOp = tensor::PadOp::create(
1133 rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
1134 expandedPadding.lowPad, expandedPadding.highPad,
1135 padOp.getConstantPaddingValue(), padOp.getNofold());
1136
1137 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1138 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1139
1140 return success();
1141 }
1142
1143private:
1144 ControlFusionFn controlFoldingReshapes;
1145};
1146
1147class FoldReshapeWithProducerPadOpByExpansion
1148 : public OpRewritePattern<tensor::ExpandShapeOp> {
1149public:
1150 FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
1151 ControlFusionFn foldReshapes,
1152 PatternBenefit benefit = 1)
1153 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1154 controlFoldingReshapes(std::move(foldReshapes)) {}
1155
1156 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1157 PatternRewriter &rewriter) const override {
1158 tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
1159 if (!padOp)
1160 return failure();
1161
1162 if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
1163 return rewriter.notifyMatchFailure(expandOp,
1164 "fusion blocked by control function");
1165 }
1166
1167 RankedTensorType expandedType = expandOp.getResultType();
1168 SmallVector<ReassociationIndices> reassociations =
1169 expandOp.getReassociationIndices();
1170 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1171 padOp, expandedType.getShape(), reassociations, rewriter);
1172 if (failed(maybeExpandedPadding))
1173 return failure();
1174 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1175
1176 Location loc = expandOp->getLoc();
1177 SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
1178 SmallVector<int64_t> newExpandedShape(expandedType.getShape());
1179 rewriter.setInsertionPointAfterValue(padOp.getSource());
1180 SmallVector<OpFoldResult> padSrcSizes =
1181 tensor::getMixedSizes(rewriter, loc, padOp.getSource());
1182 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1183 // We know that any reassociation with multiple dims is not padded because
1184 // of the requirements of computeExpandedPadding.
1185 if (reInd.size() == 1) {
1186 newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
1187 newExpandedSizes[reInd[0]] = padSrcSizes[idx];
1188 }
1189 }
1190 RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
1191 auto newExpandOp = tensor::ExpandShapeOp::create(
1192 rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
1193 newExpandedSizes);
1194 RankedTensorType expandedPaddedType =
1195 padOp.getResultType().clone(expandedPadding.paddedShape);
1196 rewriter.setInsertionPoint(expandOp);
1197 auto newPadOp = tensor::PadOp::create(
1198 rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
1199 expandedPadding.lowPad, expandedPadding.highPad,
1200 padOp.getConstantPaddingValue(), padOp.getNofold());
1201
1202 rewriter.replaceOp(expandOp, newPadOp.getResult());
1203
1204 return success();
1205 }
1206
1207private:
1208 ControlFusionFn controlFoldingReshapes;
1209};
1210
1211/// Pattern to fold a tensor.expand_shape op with its producer generic op
1212/// by expanding the dimensionality of the loop in the producer op.
1213struct FoldReshapeWithGenericOpByExpansion
1214 : public OpRewritePattern<tensor::ExpandShapeOp> {
1215
1216 FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1217 ControlFusionFn foldReshapes,
1218 PatternBenefit benefit = 1)
1219 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1220 controlFoldingReshapes(std::move(foldReshapes)) {}
1221
1222 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1223 PatternRewriter &rewriter) const override {
1224 // Fold only if all constraints of fusing with reshape by expansion are met.
1225 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1226 if (!producerResult) {
1227 return rewriter.notifyMatchFailure(reshapeOp,
1228 "source not produced by an operation");
1229 }
1230
1231 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1232 if (!producer) {
1233 return rewriter.notifyMatchFailure(reshapeOp,
1234 "producer not a generic op");
1235 }
1236
1238 producer,
1239 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1240 return rewriter.notifyMatchFailure(
1241 reshapeOp, "failed preconditions of fusion with producer generic op");
1242 }
1243
1244 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1245 return rewriter.notifyMatchFailure(reshapeOp,
1246 "fusion blocked by control function");
1247 }
1248
1249 std::optional<SmallVector<Value>> replacementValues =
1251 producer, reshapeOp,
1252 producer.getDpsInitOperand(producerResult.getResultNumber()),
1253 rewriter);
1254 if (!replacementValues) {
1255 return rewriter.notifyMatchFailure(reshapeOp,
1256 "fusion by expansion failed");
1257 }
1258
1259 // Find the replacement for the reshape op. Since the replacements have the
1260 // same type as the returns of the original generic op, the consumer reshape
1261 // op can be replaced by the source of the collapse_shape op that defines
1262 // the replacement.
1263 Value reshapeReplacement =
1264 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1265 .getResultNumber()];
1266 if (auto collapseOp =
1267 reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
1268 reshapeReplacement = collapseOp.getSrc();
1269 }
1270 rewriter.replaceOp(reshapeOp, reshapeReplacement);
1271 rewriter.replaceOp(producer, *replacementValues);
1272 return success();
1273 }
1274
1275private:
1276 ControlFusionFn controlFoldingReshapes;
1277};
1278} // namespace
1279
1280//===---------------------------------------------------------------------===//
1281// Methods and patterns to fuse reshape with linalg.generic operations by
1282// contraction of dimensions.
1283//===---------------------------------------------------------------------===//
1284
1285/// For a given list of indices in the range of the `indexingMap` that are
1286/// folded, return the indices of the corresponding domain. Return
1287/// `std::nullopt` on failure. Ensures that all the elements of the returned
1288/// reassociation are distinct.
1291 ReassociationIndicesRef rangeReassociation) {
1292 assert(indexingMap.isProjectedPermutation() &&
1293 "expected projected permutation");
1294
1295 ReassociationIndices domainReassociation =
1296 llvm::map_to_vector<4>(rangeReassociation, [&](int64_t pos) -> int64_t {
1297 return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
1298 });
1299 // The projected permutation semantics ensures that there is no repetition of
1300 // the domain indices.
1301 return domainReassociation;
1302}
1303
1304/// For a given `dimSequence`, check if the sequence is conserved in the
1305/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1306/// Non-existence of the sequence returns true as well.
1308 ReassociationIndicesRef dimSequence) {
1309 assert(!dimSequence.empty() &&
1310 "expected non-empty list for dimension sequence");
1311
1312 // Dimension sequences can only be preserved in projected permutation maps.
1313 if (!indexingMap.isProjectedPermutation()) {
1314 return false;
1315 }
1316
1317 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1318 sequenceElements.insert_range(dimSequence);
1319
1320 unsigned dimSequenceStart = dimSequence[0];
1321 for (const auto &expr : enumerate(indexingMap.getResults())) {
1322 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1323 // 1. Check if this start of the sequence.
1324 if (dimInMapStart == dimSequenceStart) {
1325 if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1326 return false;
1327 // 1a. Check if sequence is preserved.
1328 for (const auto &dimInSequence : enumerate(dimSequence)) {
1329 unsigned dimInMap =
1330 cast<AffineDimExpr>(
1331 indexingMap.getResult(expr.index() + dimInSequence.index()))
1332 .getPosition();
1333 if (dimInMap != dimInSequence.value())
1334 return false;
1335 }
1336 // Found the sequence. Projected permutation
1337 // enforces that all AffineDimExprs in the result are unique, so no
1338 // further checks are needed.
1339 return true;
1340 }
1341 // 2. If position in the expr (which is of type AffineDimExpr) is part
1342 // of sequence, return false here. This implies the entire sequence does not
1343 // exist in the indexing map.
1344 if (sequenceElements.count(dimInMapStart))
1345 return false;
1346 }
1347 // 3. No element of sequence found. Return true.
1348 return true;
1349}
1350
1353 return llvm::all_of(maps, [&](AffineMap map) {
1354 return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1355 return isDimSequencePreserved(map, dimSequence);
1356 });
1357 });
1358}
1359
1360// Return the list of dimensions of the iteration domain that can be
1361// collapsed to allow for fusion with the a producer that is an expand_shape
1362// operation. If all dimensions created by expansion can be collapsed in the
1363// iteration space then the reshape is defunct.
1364//
1365// Example:
1366//
1367// ```mlir
1368// #map = affine_map<(d0, d1) -> (d0, d1)>
1369// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1370// %2 = tensor.empty [..] : tensor<?x4xf32>
1371// %3 = linalg.generic {
1372// indexing_maps = [#map, #map],
1373// iterator_types = ["parallel" ,"parallel"]}
1374// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1375// ```
1376//
1377// can be fused by collapsing the dimensions of the iteration space.
1378//
1379// ```mlir
1380// #map = affine_map<(d0) -> (d0)>
1381// %2 = tensor.empty [..] : tensor<?xf32>
1382// %3 = linalg.generic {
1383// indexing_maps = [#map, #map],
1384// iterator_types = ["parallel"]}
1385// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1386// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1387// ```
1388//
1389// In the following example,
1390//
1391// ```mlir
1392// #map0 = affine_map<(d0, d1) -> (d0, d1)>
1393// #map1 = affine_map<(d0, d1) -> (d1, d0)>
1394// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1395// %2 = tensor.empty [..] : tensor<4x?xf32>
1396// %2 = linalg.generic {
1397// indexing_maps = [#map0, #map1],
1398// iterator_types = ["parallel" ,"parallel"]}
1399// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1400// ```
1401//
1402// the reshape cannot be fused with the generic op by collapsing the op
1403// dimensions since the indexing maps will have to contain mods and divs
1404// to preserve the accesses pattern. When no dimensions of the iteration
1405// space are collapsable and empty vector is returned.
1407getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1408 ArrayRef<ReassociationIndices> reassociation) {
1409 // Some basic checks for this fusion to be valid.
1410 if (!genericOp.hasPureTensorSemantics())
1411 return {};
1412
1413 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1414 return map.isProjectedPermutation();
1415 })) {
1416 return {};
1417 }
1418
1419 // Compute all the loops with the reduction iterator types.
1420 SmallVector<unsigned> reductionDims;
1421 genericOp.getReductionDims(reductionDims);
1422
1423 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1424 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1425 auto iteratorTypes = genericOp.getIteratorTypesArray();
1426 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1427 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1428 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1429
1430 // Ignore dims that are not folded.
1431 if (foldedRangeDims.size() == 1)
1432 continue;
1433
1434 ReassociationIndices foldedIterationSpaceDims =
1435 getDomainReassociation(indexingMap, foldedRangeDims);
1436
1437 // Check that the folded iteration dims do not contain already processed
1438 // dims.
1439 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1440 return processedIterationDims.count(dim);
1441 }))
1442 continue;
1443
1444 // Check that all folded iterator types are all parallel or all reductions.
1445 utils::IteratorType startIteratorType =
1446 iteratorTypes[foldedIterationSpaceDims[0]];
1447 if (!isParallelIterator(startIteratorType) &&
1448 !isReductionIterator(startIteratorType))
1449 continue;
1450 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1451 return iteratorTypes[dim] != startIteratorType;
1452 }))
1453 continue;
1454
1455 // If the folded dimensions correspond to a "reduction" iterator type,
1456 // the folded dimensions need to be "in-order". Strictly speaking this is
1457 // not necessary, for reductions that are associative and commutative, but
1458 // using a more strict definition of reduction for now.
1459 if (isReductionIterator(startIteratorType)) {
1460 bool isContiguous = false;
1461 for (const auto &startDim : llvm::enumerate(reductionDims)) {
1462 // Move window in `reductionDims` to start of the folded iteration dims.
1463 if (startDim.value() != foldedIterationSpaceDims[0])
1464 continue;
1465 // If sizes doesnt match, trivial not contiguous. This condition should
1466 // not be hit.
1467 if (startDim.index() + foldedIterationSpaceDims.size() >
1468 reductionDims.size())
1469 break;
1470 // Check that the contiguity is maintained.
1471 isContiguous = true;
1472 for (const auto &foldedDim :
1473 llvm::enumerate(foldedIterationSpaceDims)) {
1474 if (reductionDims[foldedDim.index() + startDim.index()] !=
1475 foldedDim.value()) {
1476 isContiguous = false;
1477 break;
1478 }
1479 }
1480 break;
1481 }
1482 if (!isContiguous)
1483 continue;
1484 }
1485
1486 // Check that the sequence is preserved in all indexing maps.
1487 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1488 [&](AffineMap indexingMap) {
1489 return !isDimSequencePreserved(indexingMap,
1490 foldedIterationSpaceDims);
1491 }))
1492 continue;
1493
1494 processedIterationDims.insert_range(foldedIterationSpaceDims);
1495 iterationSpaceReassociation.emplace_back(
1496 std::move(foldedIterationSpaceDims));
1497 }
1498
1499 return iterationSpaceReassociation;
1500}
1501
1502/// Helper class to carry state while collapsing the `linalg.generic` op.
1503namespace {
1504class CollapsingInfo {
1505public:
1506 LogicalResult initialize(unsigned origNumLoops,
1507 ArrayRef<ReassociationIndices> foldedIterationDims) {
1508 llvm::SmallDenseSet<int64_t, 4> processedDims;
1509 // Find all the dims that are folded.
1510 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1511 if (foldedIterationDim.empty())
1512 continue;
1513 // If the folded dims contain dims already folded, that's illegal
1514 // specification. Repetition within a list is also illegal.
1515 for (auto dim : foldedIterationDim) {
1516 if (dim >= origNumLoops)
1517 return failure();
1518 if (processedDims.count(dim))
1519 return failure();
1520 processedDims.insert(dim);
1521 }
1522 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1523 foldedIterationDim.end());
1524 }
1525 if (processedDims.size() > origNumLoops)
1526 return failure();
1527
1528 // Add all the preserved dims of the original op as single
1529 // elements to `collapsedOpToOrigOpIterationDim`.
1530 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1531 if (processedDims.count(dim))
1532 continue;
1533 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1534 }
1535
1536 llvm::sort(collapsedOpToOrigOpIterationDim,
1538 return lhs[0] < rhs[0];
1539 });
1540 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1541 for (const auto &foldedDims :
1542 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1543 for (const auto &dim : enumerate(foldedDims.value()))
1544 origOpToCollapsedOpIterationDim[dim.value()] =
1545 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1546 }
1547 return success();
1548 }
1549
1550 /// Return mapping from collapsed loop domain to original loop domain.
1551 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1552 return collapsedOpToOrigOpIterationDim;
1553 }
1554
1555 /// Return mapping from original loop domain to collapsed loop domain. The
1556 /// mapping is a pair. First value is the dimension in the collapsed loop that
1557 /// the original loop is mapped to. Second is the relative position in folded
1558 /// list of this domain. For example if the original loop domain is 3D, and
1559 /// the collapsed loop domain is folding all of it, i.e.
1560 ///
1561 /// ```
1562 /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
1563 /// ```
1564 ///
1565 /// then
1566 ///
1567 /// ```
1568 /// origOpToCollapsedOpMapping[0] = {0, 0};
1569 /// origOpToCollapsedOpMapping[1] = {0, 1};
1570 /// origOpToCollapsedOpMapping[2] = {0, 2};
1571 /// origOpToCollapsedOpMapping[3] = {1, 0};
1572 /// origOpToCollapsedOpMapping[4] = {1, 1};
1573 /// ```
1574 ///
1575 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1576 return origOpToCollapsedOpIterationDim;
1577 }
1578
  /// Return the rank (number of loops) of the collapsed iteration domain.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }
1583
private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the iteration
  /// domain index in the collapsed op (pair of collapsed dimension and
  /// position within the folded group).
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1592};
1593} // namespace
1594
1595/// Get the iterator types for the collapsed operation given the original
1596/// iterator types and collapsed dimensions.
1597static SmallVector<utils::IteratorType>
1598getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1599 const CollapsingInfo &collapsingInfo) {
1600 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1601 for (ReassociationIndicesRef foldedIterDims :
1602 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1603 assert(!foldedIterDims.empty() &&
1604 "reassociation indices expected to have non-empty sets");
1605 // Just pick the iterator type of the first folded dim. Pre-condition checks
1606 // expected to have checked that iterator types of all folded dimensions are
1607 // the same.
1608 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1609 }
1610 return collapsedIteratorTypes;
1611}
1612
1613/// Compute the indexing map in the collapsed op that corresponds to the given
1614/// `indexingMap` of the original operation.
1615static AffineMap
1616getCollapsedOpIndexingMap(AffineMap indexingMap,
1617 const CollapsingInfo &collapsingInfo) {
1618 MLIRContext *context = indexingMap.getContext();
1619 assert(indexingMap.isProjectedPermutation() &&
1620 "expected indexing map to be projected permutation");
1621 SmallVector<AffineExpr> resultExprs;
1622 auto origOpToCollapsedOpMapping =
1623 collapsingInfo.getOrigOpToCollapsedOpMapping();
1624 for (auto expr : indexingMap.getResults()) {
1625 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1626 // If the dim is not the first of the collapsed dim, do nothing.
1627 if (origOpToCollapsedOpMapping[dim].second != 0)
1628 continue;
1629 // The next n-dims are guaranteed to be collapsed. So just use the
1630 // iteration dimension of the collapsed op.
1631 resultExprs.push_back(
1632 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1633 }
1634 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1635 resultExprs, context);
1636}
1637
1638/// Return the `reassociation` indices to use to collapse the operand when the
1639/// iteration space of a generic op is collapsed.
1640static SmallVector<ReassociationIndices>
1641getOperandReassociation(AffineMap indexingMap,
1642 const CollapsingInfo &collapsingInfo) {
1643 unsigned counter = 0;
1644 SmallVector<ReassociationIndices> operandReassociation;
1645 auto origOpToCollapsedOpMapping =
1646 collapsingInfo.getOrigOpToCollapsedOpMapping();
1647 auto collapsedOpToOrigOpMapping =
1648 collapsingInfo.getCollapsedOpToOrigOpMapping();
1649 while (counter < indexingMap.getNumResults()) {
1650 unsigned dim =
1651 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1652 // This is the start of a collapsed dimensions of the iteration that
1653 // is gauranteed to be preserved in the indexing map. The number of folded
1654 // dims is obtained from the collapsed op to original op mapping.
1655 unsigned numFoldedDims =
1656 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1657 .size();
1658 if (origOpToCollapsedOpMapping[dim].second == 0) {
1659 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1660 operandReassociation.emplace_back(range.begin(), range.end());
1661 }
1662 counter += numFoldedDims;
1663 }
1664 return operandReassociation;
1665}
1666
1667/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1668static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1669 OpOperand *opOperand,
1670 const CollapsingInfo &collapsingInfo,
1671 OpBuilder &builder) {
1672 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1673 SmallVector<ReassociationIndices> operandReassociation =
1674 getOperandReassociation(indexingMap, collapsingInfo);
1675
1676 // If the number of entries in the reassociation for the operand is same as
1677 // the number of results of the indexing map, then nothing to do for this
1678 // operand.
1679 Value operand = opOperand->get();
1680 if (operandReassociation.size() == indexingMap.getNumResults())
1681 return operand;
1682
1683 // Insert a reshape to collapse the dimensions.
1684 if (isa<MemRefType>(operand.getType())) {
1685 return memref::CollapseShapeOp::create(builder, loc, operand,
1686 operandReassociation)
1687 .getResult();
1688 }
1689 return tensor::CollapseShapeOp::create(builder, loc, operand,
1690 operandReassociation)
1691 .getResult();
1692}
1693
1694/// Modify the `linalg.index` operations in the original generic op, to its
1695/// value in the collapsed operation.
1696static void generateCollapsedIndexingRegion(
1697 Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1698 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1699 OpBuilder::InsertionGuard g(rewriter);
1700 rewriter.setInsertionPointToStart(block);
1701
1702 // Collect all the original index ops.
1703 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1704
1705 // For each folded dimension list resolve the original induction variable
1706 // values in terms of the folded dimension induction variable.
1707 // i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1708 // can be inverted to
1709 // i2 = i_{folded} % d2
1710 // i1 = (i_{folded} / d2) % d1
1711 // i0 = i_{folded} / (d1 * d2)
1712 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1713 for (auto foldedDims :
1714 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1715 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1716 Value newIndexVal =
1717 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1718 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1719 Value loopDim =
1720 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1721 indexReplacementVals[dim] =
1722 rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1723 newIndexVal =
1724 rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1725 }
1726 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1727 }
1728
1729 for (auto indexOp : indexOps) {
1730 auto dim = indexOp.getDim();
1731 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1732 }
1733}
1734
1735static void collapseOperandsAndResults(LinalgOp op,
1736 const CollapsingInfo &collapsingInfo,
1737 RewriterBase &rewriter,
1738 SmallVectorImpl<Value> &inputOperands,
1739 SmallVectorImpl<Value> &outputOperands,
1740 SmallVectorImpl<Type> &resultTypes) {
1741 Location loc = op->getLoc();
1742 inputOperands =
1743 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1744 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1745 rewriter);
1746 });
1747
1748 // Get the output operands and result types.
1749 resultTypes.reserve(op.getNumDpsInits());
1750 outputOperands.reserve(op.getNumDpsInits());
1751 for (OpOperand &output : op.getDpsInitsMutable()) {
1752 Value newOutput =
1753 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1754 outputOperands.push_back(newOutput);
1755 // If the op has "buffer semantics", then the init operands are ranked
1756 // memrefs and the op has no results.
1757 if (!op.hasPureBufferSemantics())
1758 resultTypes.push_back(newOutput.getType());
1759 }
1760}
1761
/// Clone a `LinalgOp` to a collapsed version of same name.
///
/// Primary template: never used directly; the explicit specializations below
/// for `LinalgOp` and `GenericOp` implement the actual cloning.
template <typename OpTy>
static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                               const CollapsingInfo &collapsingInfo) {
  return nullptr;
}
1768
1769/// Collapse any `LinalgOp` that does not require any specialization such as
1770/// indexing_maps, iterator_types, etc.
1771template <>
1772LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1773 const CollapsingInfo &collapsingInfo) {
1774 SmallVector<Value> inputOperands, outputOperands;
1775 SmallVector<Type> resultTypes;
1776 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1777 outputOperands, resultTypes);
1778
1779 return clone(
1780 rewriter, origOp, resultTypes,
1781 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1782}
1783
1784/// Collapse a `GenericOp`
1785template <>
1786GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1787 GenericOp origOp,
1788 const CollapsingInfo &collapsingInfo) {
1789 SmallVector<Value> inputOperands, outputOperands;
1790 SmallVector<Type> resultTypes;
1791 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1792 outputOperands, resultTypes);
1793 SmallVector<AffineMap> indexingMaps(
1794 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1795 return getCollapsedOpIndexingMap(map, collapsingInfo);
1796 }));
1797
1798 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1799 origOp.getIteratorTypesArray(), collapsingInfo));
1800
1801 GenericOp collapsedOp = linalg::GenericOp::create(
1802 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1803 indexingMaps, iteratorTypes,
1804 [](OpBuilder &builder, Location loc, ValueRange args) {});
1805 Block *origOpBlock = &origOp->getRegion(0).front();
1806 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1807 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1808 collapsedOpBlock->getArguments());
1809 return collapsedOp;
1810}
1811
1812static LinalgOp createCollapsedOp(LinalgOp op,
1813 const CollapsingInfo &collapsingInfo,
1814 RewriterBase &rewriter) {
1815 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1816 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1817 }
1818 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1819}
1820
/// Implementation of fusion with reshape operation by collapsing dimensions.
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  // Bail on trivial no-op cases: nothing to collapse when there is at most one
  // loop, or when every requested group folds at most one dimension.
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  // Compute the collapsed <-> original iteration-dimension mappings; this
  // fails if the requested folding is not a valid grouping of the loops.
  CollapsingInfo collapsingInfo;
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        op, "illegal to collapse specified dimensions");
  }

  // For ops with buffer semantics the collapse is materialized with
  // memref.collapse_shape, which is only valid when every memref operand is
  // guaranteed collapsible for its reassociation.
  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
  if (hasPureBufferSemantics &&
      !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
        MemRefType memRefToCollapse =
            dyn_cast<MemRefType>(opOperand.get().getType());
        if (!memRefToCollapse)
          return true;

        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
        SmallVector<ReassociationIndices> operandReassociation =
            getOperandReassociation(indexingMap, collapsingInfo);
        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, operandReassociation);
      }))
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  // Bail on non-canonical ranges: every loop must start at 0 with stride 1.
  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
    llvm::APInt actual;
    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
           actual.getSExtValue() == value;
  };
  if (!llvm::all_of(loopRanges, [&](Range range) {
        return opFoldIsConstantValue(range.offset, 0) &&
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        op, "expected all loop ranges to have zero start and unit stride");
  }

  // Create the op with the collapsed iteration space.
  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

  Location loc = op->getLoc();
  SmallVector<OpFoldResult> loopBound =
      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });

  // If the body uses linalg.index, rewrite those indices in terms of the
  // collapsed induction variables.
  if (collapsedOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedOp);
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      assert(
          indexingMap.isProjectedPermutation() &&
          "Expected indexing map to be a projected permutation for collapsing");
      // The expanded output shape is the loop bounds permuted by the result's
      // indexing map.
      SmallVector<OpFoldResult> resultShape =
          applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        result = memref::ExpandShapeOp::create(
            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
            resultShape);
      } else {
        result = tensor::ExpandShapeOp::create(
            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
            resultShape);
      }
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return CollapseResult{results, collapsedOp};
}
1922
1923namespace {
1924
1925/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1926/// contracting dimensions of the loop.
1927class FoldWithProducerReshapeOpByCollapsing
1928 : public OpRewritePattern<GenericOp> {
1929public:
1930 // TODO : support fusion with all linalg ops, not just generic.
1931 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1932 ControlFusionFn foldReshapes,
1933 PatternBenefit benefit = 1)
1934 : OpRewritePattern<GenericOp>(context, benefit),
1935 controlFoldingReshapes(std::move(foldReshapes)) {}
1936
1937 LogicalResult matchAndRewrite(GenericOp genericOp,
1938 PatternRewriter &rewriter) const override {
1939 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1940 tensor::ExpandShapeOp reshapeOp =
1941 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1942 if (!reshapeOp)
1943 continue;
1944
1945 SmallVector<ReassociationIndices> collapsableIterationDims =
1946 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1947 reshapeOp.getReassociationIndices());
1948 if (collapsableIterationDims.empty() ||
1949 !controlFoldingReshapes(&opOperand)) {
1950 continue;
1951 }
1952
1953 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1954 genericOp, collapsableIterationDims, rewriter);
1955 if (!collapseResult) {
1956 return rewriter.notifyMatchFailure(
1957 genericOp, "failed to do the fusion by collapsing transformation");
1958 }
1959
1960 rewriter.replaceOp(genericOp, collapseResult->results);
1961 return success();
1962 }
1963 return failure();
1964 }
1965
1966private:
1967 ControlFusionFn controlFoldingReshapes;
1968};
1969
1970/// Pattern to fold a tensor.collapse_shape op with its producer generic op
1971/// by expanding the dimensionality of the loop in the producer op.
1972struct FoldReshapeWithGenericOpByCollapsing
1973 : public OpRewritePattern<tensor::CollapseShapeOp> {
1974
1975 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1976 ControlFusionFn foldReshapes,
1977 PatternBenefit benefit = 1)
1978 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1979 controlFoldingReshapes(std::move(foldReshapes)) {}
1980
1981 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1982 PatternRewriter &rewriter) const override {
1983 // Fold only if all constraints of fusing with reshape by collapsing are
1984 // met.
1985 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1986 if (!producerResult) {
1987 return rewriter.notifyMatchFailure(reshapeOp,
1988 "source not produced by an operation");
1989 }
1990
1991 // TODO : support fusion with all linalg producers, not just generic.
1992 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1993 if (!producer) {
1994 return rewriter.notifyMatchFailure(reshapeOp,
1995 "producer not a generic op");
1996 }
1997
1998 SmallVector<ReassociationIndices> collapsableIterationDims =
2000 producer,
2001 producer.getDpsInitOperand(producerResult.getResultNumber()),
2002 reshapeOp.getReassociationIndices());
2003 if (collapsableIterationDims.empty()) {
2004 return rewriter.notifyMatchFailure(
2005 reshapeOp, "failed preconditions of fusion with producer generic op");
2006 }
2007
2008 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2009 return rewriter.notifyMatchFailure(reshapeOp,
2010 "fusion blocked by control function");
2011 }
2012
2013 // Set the insertion point after `producer` because there could be uses
2014 // of `producer` between it and the `tensor.collapse_shape` op.
2015 rewriter.setInsertionPointAfter(producer);
2016 std::optional<CollapseResult> collapseResult =
2017 collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
2018 if (!collapseResult) {
2019 return rewriter.notifyMatchFailure(
2020 producer, "failed to do the fusion by collapsing transformation");
2021 }
2022
2023 rewriter.replaceOp(producer, collapseResult->results);
2024 return success();
2025 }
2026
2027private:
2028 ControlFusionFn controlFoldingReshapes;
2029};
2030
2031/// Computes the collapsed padding information for the given pad operation based
2032/// on the provided collapsed shape and reassociation indices. Returns a
2033/// PadDimInfo containing the low and high padding amounts and the collapsed
2034/// shape for each dimension, or failure if the collapse is not possible.
2035static FailureOr<PadDimInfo>
2036computeCollapsedPadding(tensor::PadOp padOp,
2037 ArrayRef<ReassociationIndices> reassociations,
2038 PatternRewriter &rewriter) {
2039 // If the padding value depends on the index values of the pad operation,
2040 // then it may not be valid to collapse the dimensions, since it will change
2041 // the index values on which the padding value depends. This is not currently
2042 // supported by the pad collapsing patterns, but it could be implemented
2043 // similarly to the collapsing of linalg.generic ops with linalg.index ops in
2044 // the body, as is done in `generateCollapsedIndexingRegion`.
2045 if (!padOp.getConstantPaddingValue())
2046 return failure();
2047
2048 // Collapsed dimensions cannot have padding because this can produce strided
2049 // padding that isn't representable by a tensor.pad op. There are some special
2050 // cases where it is possible (like collapsing unit dims), but supporting
2051 // these cases is NYI, so disallow it for now.
2052 ArrayRef<int64_t> low = padOp.getStaticLow();
2053 ArrayRef<int64_t> high = padOp.getStaticHigh();
2054 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2055 for (int64_t dim : reInd) {
2056 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2057 return failure();
2058 }
2059 }
2060
2061 // Initialize padding values for collapsed tensors with zeros
2062 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2063 PadDimInfo padDimInfo;
2064 padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2065 padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2066
2067 // Update padding for dimensions that are not being collapsed, and compute
2068 // the collapsed padded shape.
2069 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2070 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2071 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2072 if (reInd.size() == 1) {
2073 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2074 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2075 }
2076 SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
2077 for (int64_t dim : reInd) {
2078 collapsedSize =
2079 collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
2080 }
2081 padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
2082 }
2083
2084 return padDimInfo;
2085}
2086
/// Pattern to fold a tensor.expand_shape op into a consumer tensor.pad op by
/// re-expressing the padding on the un-expanded source, i.e.
/// `pad(expand_shape(x))` -> `expand_shape(pad(x))`.
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    // Compute the padding in the collapsed (pre-expansion) domain; fails when
    // the padding is not expressible after collapsing.
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeCollapsedPadding =
        computeCollapsedPadding(padOp, reassociations, rewriter);
    if (failed(maybeCollapsedPadding))
      return failure();
    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();

    // Compute the output sizes of the new expand_shape: un-collapsed
    // dimensions grow by low + high padding (folded with an affine apply).
    SmallVector<OpFoldResult> expandedPaddedSizes =
        reshapeOp.getMixedOutputShape();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [reInd, l, h] :
         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
                         collapsedPadding.highPad)) {
      if (reInd.size() == 1) {
        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
      }
    }

    // Pad the collapsed source, then expand back to the original padded type.
    RankedTensorType collapsedPaddedType =
        padOp.getType().clone(collapsedPadding.paddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
        collapsedPadding.lowPad, collapsedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
2148
2149class FoldReshapeWithProducerPadOpByCollapsing
2150 : public OpRewritePattern<tensor::CollapseShapeOp> {
2151public:
2152 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2153 ControlFusionFn foldReshapes,
2154 PatternBenefit benefit = 1)
2155 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2156 controlFoldingReshapes(std::move(foldReshapes)) {}
2157
2158 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2159 PatternRewriter &rewriter) const override {
2160 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2161 if (!padOp)
2162 return failure();
2163
2164 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2165 return rewriter.notifyMatchFailure(padOp,
2166 "fusion blocked by control function");
2167 }
2168
2169 SmallVector<ReassociationIndices> reassociations =
2170 reshapeOp.getReassociationIndices();
2171 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2172 FailureOr<PadDimInfo> maybeCollapsedPadding =
2173 computeCollapsedPadding(padOp, reassociations, rewriter);
2174 if (failed(maybeCollapsedPadding))
2175 return failure();
2176 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2177
2178 Location loc = reshapeOp->getLoc();
2179 auto newCollapseOp = tensor::CollapseShapeOp::create(
2180 rewriter, loc, padOp.getSource(), reassociations);
2181
2182 auto newPadOp = tensor::PadOp::create(
2183 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2184 collapsedPadding.lowPad, collapsedPadding.highPad,
2185 padOp.getConstantPaddingValue(), padOp.getNofold());
2186
2187 rewriter.replaceOp(reshapeOp, newPadOp.getResult());
2188 return success();
2189 }
2190
2191private:
2192 ControlFusionFn controlFoldingReshapes;
2193};
2194
2195/// Pattern to collapse dimensions.
2196template <typename LinalgType>
2197class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2198public:
2199 CollapseLinalgDimensions(MLIRContext *context,
2200 GetCollapsableDimensionsFn collapseDimensions,
2201 PatternBenefit benefit = 1)
2202 : OpRewritePattern<LinalgType>(context, benefit),
2203 controlCollapseDimension(std::move(collapseDimensions)) {}
2204
2205 LogicalResult matchAndRewrite(LinalgType op,
2206 PatternRewriter &rewriter) const override {
2207 SmallVector<ReassociationIndices> collapsableIterationDims =
2208 controlCollapseDimension(op);
2209 if (collapsableIterationDims.empty())
2210 return failure();
2211
2212 // Check if the specified list of dimensions to collapse is a valid list.
2213 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2214 collapsableIterationDims)) {
2215 return rewriter.notifyMatchFailure(
2216 op, "specified dimensions cannot be collapsed");
2217 }
2218
2219 std::optional<CollapseResult> collapseResult =
2220 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2221 if (!collapseResult) {
2222 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2223 }
2224 rewriter.replaceOp(op, collapseResult->results);
2225 return success();
2226 }
2227
2228private:
2229 GetCollapsableDimensionsFn controlCollapseDimension;
2230};
2231
2232} // namespace
2233
2234//===---------------------------------------------------------------------===//
2235// Methods and patterns that fuse constants with linalg.generic operations.
2236//===---------------------------------------------------------------------===//
2237
2238namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      // Matches a splat dense constant, an integer constant, or a float
      // constant; on success `constantAttr` is set to the scalar value as a
      // side effect.
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the same
      // as the operands and indexing_maps of the generic operation with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp->getNumOperands());
      fusedOperands.reserve(genericOp.getNumDpsInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        // Skip the constant operand being folded away.
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(&outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(
              concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant =
          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);

      SmallVector<Value> outputOperands = genericOp.getOutputs();
      auto fusedOp =
          GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
                            genericOp->getResultTypes(),
                            /*inputs=*/fusedOperands,
                            /*outputs=*/outputOperands,
                            rewriter.getAffineMapArrayAttr(fusedIndexMaps),
                            genericOp.getIteratorTypes(),
                            /*doc=*/nullptr,
                            /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      IRMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};
2344
2345} // namespace
2346
2347//===---------------------------------------------------------------------===//
2348// Miscellaneous patterns that help fusion.
2349//===---------------------------------------------------------------------===//
2350
2351namespace {
2352/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2353/// value of the `outs` operand is not used within the op. This is only
2354/// implemented for `linalg.generic` operations for now, but should hold for all
2355/// linalg structured ops.
2356struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2357 using OpRewritePattern<GenericOp>::OpRewritePattern;
2358
2359 LogicalResult matchAndRewrite(GenericOp op,
2360 PatternRewriter &rewriter) const override {
2361 rewriter.startOpModification(op);
2362 bool modifiedOutput = false;
2363 Location loc = op.getLoc();
2364 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2365 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2366 Value operandVal = opOperand.get();
2367 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2368 if (!operandType)
2369 continue;
2370
2371 // If outs is sparse, leave it to the sparsifier.
2373 continue;
2374
2375 // If outs is already an `empty` operation, nothing to do.
2376 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2377 if (definingOp)
2378 continue;
2379 modifiedOutput = true;
2380 SmallVector<OpFoldResult> mixedSizes =
2381 tensor::getMixedSizes(rewriter, loc, operandVal);
2382 Value emptyTensor = tensor::EmptyOp::create(
2383 rewriter, loc, mixedSizes, operandType.getElementType());
2384 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2385 }
2386 }
2387 if (!modifiedOutput) {
2388 rewriter.cancelOpModification(op);
2389 return failure();
2390 }
2391 rewriter.finalizeOpModification(op);
2392 return success();
2393 }
2394};
2395
2396/// Fold linalg.fill into linalg.generic
2397struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2398 using OpRewritePattern<GenericOp>::OpRewritePattern;
2399
2400 LogicalResult matchAndRewrite(GenericOp genericOp,
2401 PatternRewriter &rewriter) const override {
2402 if (!genericOp.hasPureTensorSemantics())
2403 return failure();
2404 bool fillFound = false;
2405 Block &payload = genericOp.getRegion().front();
2406 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2407 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2408 continue;
2409 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2410 if (!fillOp)
2411 continue;
2412 fillFound = true;
2413 Value fillVal = fillOp.value();
2414 auto resultType =
2415 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2416 Value convertedVal =
2417 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2418 /*isUnsignedCast =*/false);
2419 rewriter.replaceAllUsesWith(
2420 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2421 }
2422 return success(fillFound);
2423 }
2424};
2425} // namespace
2426
2428 RewritePatternSet &patterns,
2429 const ControlFusionFn &controlFoldingReshapes) {
2430 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2431 controlFoldingReshapes);
2432 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2433 controlFoldingReshapes);
2434 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
2435 controlFoldingReshapes);
2436 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2437 controlFoldingReshapes);
2438}
2439
2441 RewritePatternSet &patterns,
2442 const ControlFusionFn &controlFoldingReshapes) {
2443 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2444 controlFoldingReshapes);
2445 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2446 patterns.getContext(), controlFoldingReshapes);
2447 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2448 patterns.getContext(), controlFoldingReshapes);
2449 patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2450 controlFoldingReshapes);
2451}
2452
2454 RewritePatternSet &patterns,
2455 const ControlFusionFn &controlElementwiseOpsFusion) {
2456 auto *context = patterns.getContext();
2457 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2458 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2459 RemoveOutsDependency>(context);
2460 // Add the patterns that clean up dead operands and results.
2462}
2463
2465 RewritePatternSet &patterns,
2466 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2467 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2468 CollapseLinalgDimensions<linalg::CopyOp>>(
2469 patterns.getContext(), controlCollapseDimensions);
2470}
2471
2472//===---------------------------------------------------------------------===//
2473// Passes
2474//===---------------------------------------------------------------------===//
2475
2476namespace {
2477
2478/// Pass that fuses generic ops on tensors. Used only for testing.
2479// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2480// patterns added here heavily depends on the cost function used. Having an
2481// opinionated pass of this form is not recommended. Deprecate this pass in
2482// favor of test passes that check the functionality of each of the patterns
2483// added here individually.
2484struct LinalgElementwiseOpFusionPass
2485 : public impl::LinalgElementwiseOpFusionPassBase<
2486 LinalgElementwiseOpFusionPass> {
2487 using impl::LinalgElementwiseOpFusionPassBase<
2488 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2489 void runOnOperation() override {
2490 Operation *op = getOperation();
2491 MLIRContext *context = op->getContext();
2492 RewritePatternSet patterns(context);
2493
2494 // Add folding with reshape by expansion patterns.
2495 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2496 Operation *producer = fusedOperand->get().getDefiningOp();
2497 return producer && producer->hasOneUse();
2498 };
2499
2500 // Add elementwise op fusion patterns.
2501 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2502 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2504
2505 // General canonicalization patterns.
2506 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2507 GenericOp::getCanonicalizationPatterns(patterns, context);
2508 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2509 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2510 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2511 patterns);
2512
2513 // Add constant folding patterns.
2514 populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2515
2516 // Use TopDownTraversal for compile time reasons.
2517 (void)applyPatternsGreedily(op, std::move(patterns),
2518 GreedyRewriteConfig().setUseTopDownTraversal());
2519 }
2520};
2521
2522} // namespace
return success()
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
lhs
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:593
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:878
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:232
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
This transformation is intended to be used with a top-down traversal (from producer to consumer).
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:239
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Definition Transforms.h:656
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.