//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Compute the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the producer's result in the
/// consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
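
// As a concrete illustration of the composition above (example maps only, not
// drawn from a specific caller): suppose the producer writes its result with
// producerResultIndexMap = (d0, d1) -> (d1, d0) and reads the argument with
// argMap = (d0, d1) -> (d0, d1), while the consumer accesses the fused operand
// with fusedConsumerArgIndexMap = (c0, c1) -> (c1, c0). Then:
//
//   invProducerResultIndexMap                      : (t0, t1) -> (t1, t0)
//   t1 = argMap.compose(invProducerResultIndexMap) : (t0, t1) -> (t1, t0)
//   t1.compose(fusedConsumerArgIndexMap)           : (c0, c1) -> (c0, c1)
//
// i.e. the producer argument is read with the identity map in the fused
// iteration space: the producer-side and consumer-side transposes cancel out.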

// Checks if the given operand can be dropped, and the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
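
// For example (illustrative IR, with names chosen here for exposition), the
// following pair satisfies all of the above conditions and is fusable through
// the consumer's first input operand:
//
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//          ins(%a : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//          ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%empty : tensor<?x?xf32>) {...}
//
// Both ops are generics on tensors, the producer is all-parallel, %0 feeds an
// input operand of the consumer, and the producer result's indexing map (#id)
// is a permutation.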

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of the consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
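
// A minimal driver sketch (illustrative, not part of this file): the usual
// public entry point is `populateElementwiseOpsFusionPatterns` from Linalg's
// Transforms.h, with a control function that filters candidate operands. The
// all-accepting control function below is an assumption for the example.
//
//   RewritePatternSet patterns(context);
//   ControlFusionFn controlFn = [](OpOperand *fusedOperand) { return true; };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// Callers using `fuseElementwiseOps` directly must check
// `areElementwiseOpsFusable` first and then apply the returned `replacements`
// themselves, as `FuseElementwiseOps::matchAndRewrite` below does.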

namespace {
/// Patterns to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Find the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map helps compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6D, reshapes can be
///  introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of each loop in the original operation to the
  /// extents of the corresponding loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
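
// Worked example (values invented for illustration): if the fused operand's
// indexing map is (d0, d1) -> (d1, d0), the reshape's reassociation maps are
// [[0, 1], [2]] (tensor dim 0 expands into two dims, tensor dim 1 stays), and
// the expanded shape is [4, ?, 8], then:
//   - d1 (tensor dim 0) expands into 2 loops with shapes [4, ?],
//   - d0 (tensor dim 1) stays a single loop with shape [8],
// so reassociation = [[0], [1, 2]] (d0 -> {0}, d1 -> {1, 2}) and
// expandedOpNumDims = 3.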

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
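
// Continuing the worked example above (d0 -> {0}, d1 -> {1, 2}): an original
// indexing map (d0, d1) -> (d1, d0) becomes (e0, e1, e2) -> (e1, e2, e0) in
// the expanded op; each dim in the original result list is replaced, in
// place, by the run of expanded dims it maps to.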

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
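
// With the same expansion (d0 -> {0}, d1 -> {1, 2}), an operand indexed by
// (d0, d1) -> (d1, d0) gets the reassociation [[0, 1], [2]]: its first tensor
// dimension (d1) covers two expanded dims and its second (d0) covers one,
// numbered consecutively in tensor-dimension order.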

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
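
// The linearization built above is the usual row-major collapse. If an
// original dimension expands into dims (p0, p1, p2) with shapes (s0, s1, s2),
// the recovered index is
//
//   idx = (idx(p0) * s1 + idx(p1)) * s2 + idx(p2)
//
// which the loop accumulates one expanded dimension at a time via
// newIndex = idx + acc * shape.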

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation=[4, 5, 0, 1, 2, 3]
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
                                            TransposeOp transposeOp,
                                            Value expandedInput, Value output,
                                            ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Initialize the iterator types of the expanded op to all parallel, then
  // copy each original iterator type to its expanded dimensions.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy, and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in output shape before the `linalgOp`
    // to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Carries information about a padded dimension.
struct PadDimInfo {
  // The resulting shape after padding each dimension.
  SmallVector<int64_t> paddedShape;

  // Low and high padding amounts for each dimension.
  SmallVector<OpFoldResult> lowPad;
  SmallVector<OpFoldResult> highPad;
};

/// Computes the expanded padding information for the given pad operation based
/// on the provided expanded shape and reassociation indices. Returns a
/// PadDimInfo containing the low and high padding amounts and the padded
/// size for each dimension, or failure if the expansion is not possible.
static FailureOr<PadDimInfo>
computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
                       ArrayRef<ReassociationIndices> reassociations,
                       PatternRewriter &rewriter) {
  // If the padding value depends on the index values of the pad operation,
  // then it may not be valid to expand the dimensions, since it will change
  // the index values on which the padding value depends. This is not currently
  // supported by the pad expansion patterns, but it could be implemented
  // similarly to the expansion of linalg.generic ops with linalg.index ops in
  // the body, as is done in `updateExpandedGenericOpRegion`.
  if (!padOp.getConstantPaddingValue())
    return failure();

  // Expanded dimensions cannot have padding because the resulting padding may
  // not be representable by a tensor.pad op. There are some special cases where
  // it is possible (like expanding unit dims), but supporting these cases is
  // NYI, so disallow it for now.
  ArrayRef<int64_t> low = padOp.getStaticLow();
  ArrayRef<int64_t> high = padOp.getStaticHigh();
  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
    if (reInd.size() != 1 && (l != 0 || h != 0))
      return failure();
  }

  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
  PadDimInfo padDimInfo;
  padDimInfo.paddedShape.assign(expandedShape);
  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
    if (reInd.size() == 1) {
      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
    }
  }

  return padDimInfo;
}
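
// For example (shapes invented for illustration), with
//
//   %c = tensor.collapse_shape %src [[0, 1], [2]]
//       : tensor<8x4x?xf32> into tensor<32x?xf32>
//   %p = tensor.pad %c low[0, 2] high[0, 3] { ... }
//
// the expanded padding is lowPad = [0, 0, 2] and highPad = [0, 0, 3] over the
// 8x4x? shape. Any nonzero padding on collapsed dim 0 (which expands to dims
// [0, 1]) would instead make computeExpandedPadding fail.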

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = reshapeOp.getSrcType();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);

    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldReshapeWithProducerPadOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
public:
  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(expandOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = expandOp.getResultType();
    SmallVector<ReassociationIndices> reassociations =
        expandOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
    rewriter.setInsertionPointAfterValue(padOp.getSource());
    SmallVector<OpFoldResult> padSrcSizes =
        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      // We know that any reassociation with multiple dims is not padded because
      // of the requirements of computeExpandedPadding.
      if (reInd.size() == 1) {
        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
      }
    }
    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
    auto newExpandOp = tensor::ExpandShapeOp::create(
        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
        newExpandedSizes);
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);
    rewriter.setInsertionPoint(expandOp);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOp(expandOp, newPadOp.getResult());

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
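
// For example, with indexingMap = (d0, d1, d2, d3) -> (d1, d3, d0) and
// rangeReassociation = [1, 2] (folding result dims 1 and 2), the returned
// domain reassociation is [3, 0]: the loop dimensions that produce those
// result dims.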

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
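
// Examples:
//   - map = (d0, d1, d2) -> (d1, d0, d2), sequence [0, 1]: not preserved,
//     since d1 (an element of the sequence) appears before d0.
//   - map = (d0, d1, d2) -> (d0, d1), sequence [0, 1]: preserved.
//   - map = (d0, d1, d2) -> (d2), sequence [0, 1]: preserved, since the
//     sequence does not appear at all.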

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}
1355
1356// Return the list of dimensions of the iteration domain that can be
1357// collapsed to allow for fusion with the a producer that is an expand_shape
1358// operation. If all dimensions created by expansion can be collapsed in the
1359// iteration space then the reshape is defunct.
1360//
1361// Example:
1362//
1363// ```mlir
1364// #map = affine_map<(d0, d1) -> (d0, d1)>
1365// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1366// %2 = tensor.empty [..] : tensor<?x4xf32>
1367// %3 = linalg.generic {
1368// indexing_maps = [#map, #map],
1369// iterator_types = ["parallel" ,"parallel"]}
1370// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1371// ```
1372//
1373// can be fused by collapsing the dimensions of the iteration space.
1374//
1375// ```mlir
1376// #map = affine_map<(d0) -> (d0)>
1377// %2 = tensor.empty [..] : tensor<?xf32>
1378// %3 = linalg.generic {
1379// indexing_maps = [#map, #map],
1380// iterator_types = ["parallel"]}
1381// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1382// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1383// ```
1384//
1385// In the following example,
1386//
1387// ```mlir
1388// #map0 = affine_map<(d0, d1) -> (d0, d1)>
1389// #map1 = affine_map<(d0, d1) -> (d1, d0)>
1390// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1391// %2 = tensor.empty [..] : tensor<4x?xf32>
1392 // %3 = linalg.generic {
1393// indexing_maps = [#map0, #map1],
1394// iterator_types = ["parallel" ,"parallel"]}
1395// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1396// ```
1397//
1398// the reshape cannot be fused with the generic op by collapsing the op
1399// dimensions since the indexing maps will have to contain mods and divs
1400 // to preserve the access pattern. When no dimensions of the iteration
1401 // space are collapsable, an empty vector is returned.
1402static SmallVector<ReassociationIndices>
1403getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1404 ArrayRef<ReassociationIndices> reassociation) {
1405 // Some basic checks for this fusion to be valid.
1406 if (!genericOp.hasPureTensorSemantics())
1407 return {};
1408
1409 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1410 return map.isProjectedPermutation();
1411 })) {
1412 return {};
1413 }
1414
1415 // Compute all the loops with the reduction iterator types.
1416 SmallVector<unsigned> reductionDims;
1417 genericOp.getReductionDims(reductionDims);
1418
1419 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1420 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1421 auto iteratorTypes = genericOp.getIteratorTypesArray();
1422 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1423 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1424 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1425
1426 // Ignore dims that are not folded.
1427 if (foldedRangeDims.size() == 1)
1428 continue;
1429
1430 ReassociationIndices foldedIterationSpaceDims =
1431 getDomainReassociation(indexingMap, foldedRangeDims);
1432
1433 // Check that the folded iteration dims do not contain already processed
1434 // dims.
1435 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1436 return processedIterationDims.count(dim);
1437 }))
1438 continue;
1439
1440 // Check that the folded iterator types are all parallel or all reduction.
1441 utils::IteratorType startIteratorType =
1442 iteratorTypes[foldedIterationSpaceDims[0]];
1443 if (!isParallelIterator(startIteratorType) &&
1444 !isReductionIterator(startIteratorType))
1445 continue;
1446 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1447 return iteratorTypes[dim] != startIteratorType;
1448 }))
1449 continue;
1450
1451 // If the folded dimensions correspond to a "reduction" iterator type,
1452 // the folded dimensions need to be "in-order". Strictly speaking this is
1453 // not necessary for reductions that are associative and commutative, but
1454 // a stricter definition of reduction is used for now.
1455 if (isReductionIterator(startIteratorType)) {
1456 bool isContiguous = false;
1457 for (const auto &startDim : llvm::enumerate(reductionDims)) {
1458 // Move window in `reductionDims` to start of the folded iteration dims.
1459 if (startDim.value() != foldedIterationSpaceDims[0])
1460 continue;
1461 // If the sizes don't match, it is trivially not contiguous. This
1462 // condition should not be hit.
1463 if (startDim.index() + foldedIterationSpaceDims.size() >
1464 reductionDims.size())
1465 break;
1466 // Check that the contiguity is maintained.
1467 isContiguous = true;
1468 for (const auto &foldedDim :
1469 llvm::enumerate(foldedIterationSpaceDims)) {
1470 if (reductionDims[foldedDim.index() + startDim.index()] !=
1471 foldedDim.value()) {
1472 isContiguous = false;
1473 break;
1474 }
1475 }
1476 break;
1477 }
1478 if (!isContiguous)
1479 continue;
1480 }
1481
1482 // Check that the sequence is preserved in all indexing maps.
1483 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1484 [&](AffineMap indexingMap) {
1485 return !isDimSequencePreserved(indexingMap,
1486 foldedIterationSpaceDims);
1487 }))
1488 continue;
1489
1490 processedIterationDims.insert_range(foldedIterationSpaceDims);
1491 iterationSpaceReassociation.emplace_back(
1492 std::move(foldedIterationSpaceDims));
1493 }
1494
1495 return iterationSpaceReassociation;
1496}
1497
1498/// Helper class to carry state while collapsing the `linalg.generic` op.
1499namespace {
1500class CollapsingInfo {
1501public:
1502 LogicalResult initialize(unsigned origNumLoops,
1503 ArrayRef<ReassociationIndices> foldedIterationDims) {
1504 llvm::SmallDenseSet<int64_t, 4> processedDims;
1505 // Find all the dims that are folded.
1506 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1507 if (foldedIterationDim.empty())
1508 continue;
1509 // If the folded dims contain dims already folded, that's an illegal
1510 // specification. Repetition within a list is also illegal.
1511 for (auto dim : foldedIterationDim) {
1512 if (dim >= origNumLoops)
1513 return failure();
1514 if (processedDims.count(dim))
1515 return failure();
1516 processedDims.insert(dim);
1517 }
1518 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1519 foldedIterationDim.end());
1520 }
1521 if (processedDims.size() > origNumLoops)
1522 return failure();
1523
1524 // Add all the preserved dims of the original op as single
1525 // elements to `collapsedOpToOrigOpIterationDim`.
1526 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1527 if (processedDims.count(dim))
1528 continue;
1529 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1530 }
1531
1532 llvm::sort(collapsedOpToOrigOpIterationDim,
1533 [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1534 return lhs[0] < rhs[0];
1535 });
1536 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1537 for (const auto &foldedDims :
1538 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1539 for (const auto &dim : enumerate(foldedDims.value()))
1540 origOpToCollapsedOpIterationDim[dim.value()] =
1541 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1542 }
1543 return success();
1544 }
1545
1546 /// Return mapping from collapsed loop domain to original loop domain.
1547 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1548 return collapsedOpToOrigOpIterationDim;
1549 }
1550
1551 /// Return mapping from original loop domain to collapsed loop domain. The
1552 /// mapping is a pair. First value is the dimension in the collapsed loop that
1553 /// the original loop is mapped to. Second is the relative position in the
1554 /// folded list for this dimension. For example, if the original loop domain
1555 /// is 5D and the collapsed loop domain folds all of it into two groups, i.e.
1556 ///
1557 /// ```
1558 /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1559 /// ```
1560 ///
1561 /// then
1562 ///
1563 /// ```
1564 /// origOpToCollapsedOpMapping[0] = {0, 0};
1565 /// origOpToCollapsedOpMapping[1] = {0, 1};
1566 /// origOpToCollapsedOpMapping[2] = {0, 2};
1567 /// origOpToCollapsedOpMapping[3] = {1, 0};
1568 /// origOpToCollapsedOpMapping[4] = {1, 1};
1569 /// ```
1570 ///
1571 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1572 return origOpToCollapsedOpIterationDim;
1573 }
1574
1575 /// Return the collapsed op iteration domain rank.
1576 unsigned getCollapsedOpIterationRank() const {
1577 return collapsedOpToOrigOpIterationDim.size();
1578 }
1579
1580private:
1581 /// Map from the iteration domain index in collapsed op to the iteration
1582 /// domain indices in the original op.
1583 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1584
1585 /// Map from iteration domain index in the original op to the iteration domain
1586 /// index in the collapsed op.
1587 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1588};
1589} // namespace
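// A hedged usage sketch of `CollapsingInfo` (values assumed, not from the
// upstream file): folding loops (0, 1) and (3, 4) of a 5-loop op leaves loop
// 2 as a singleton group.
//
//   CollapsingInfo info;
//   (void)info.initialize(/*origNumLoops=*/5, {{0, 1}, {3, 4}});
//   // info.getCollapsedOpToOrigOpMapping() == [[0, 1], [2], [3, 4]]
//   // info.getCollapsedOpIterationRank()   == 3
//   // info.getOrigOpToCollapsedOpMapping() ==
//   //     [{0, 0}, {0, 1}, {1, 0}, {2, 0}, {2, 1}]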
1590
1591/// Get the iterator types for the collapsed operation given the original
1592/// iterator types and collapsed dimensions.
1593static SmallVector<utils::IteratorType>
1594getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1595 const CollapsingInfo &collapsingInfo) {
1596 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1597 for (ReassociationIndicesRef foldedIterDims :
1598 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1599 assert(!foldedIterDims.empty() &&
1600 "reassociation indices expected to have non-empty sets");
1601 // Just pick the iterator type of the first folded dim. Pre-condition
1602 // checks are expected to have verified that the iterator types of all
1603 // folded dimensions are the same.
1604 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1605 }
1606 return collapsedIteratorTypes;
1607}
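// Continuing the assumed 5-loop example above: with original iterator types
// [parallel, parallel, reduction, parallel, parallel] and folded groups
// {0, 1} and {3, 4}, the collapsed op gets [parallel, reduction, parallel].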
1608
1609/// Compute the indexing map in the collapsed op that corresponds to the given
1610/// `indexingMap` of the original operation.
1611static AffineMap
1612getCollapsedOpIndexingMap(AffineMap indexingMap,
1613 const CollapsingInfo &collapsingInfo) {
1614 MLIRContext *context = indexingMap.getContext();
1615 assert(indexingMap.isProjectedPermutation() &&
1616 "expected indexing map to be projected permutation");
1617 SmallVector<AffineExpr> resultExprs;
1618 auto origOpToCollapsedOpMapping =
1619 collapsingInfo.getOrigOpToCollapsedOpMapping();
1620 for (auto expr : indexingMap.getResults()) {
1621 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1622 // If the dim is not the first of its collapsed group, do nothing.
1623 if (origOpToCollapsedOpMapping[dim].second != 0)
1624 continue;
1625 // The next n-dims are guaranteed to be collapsed. So just use the
1626 // iteration dimension of the collapsed op.
1627 resultExprs.push_back(
1628 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1629 }
1630 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1631 resultExprs, context);
1632}
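// Continuing the same assumed example (folded groups {0, 1} and {3, 4}): an
// original indexing map
//   (d0, d1, d2, d3, d4) -> (d3, d4, d0, d1, d2)
// collapses to
//   (d0, d1, d2) -> (d2, d0, d1)
// since each folded group is replaced by its single collapsed dimension.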
1633
1634/// Return the `reassociation` indices to use to collapse the operand when the
1635/// iteration space of a generic op is collapsed.
1636static SmallVector<ReassociationIndices>
1637getOperandReassociation(AffineMap indexingMap,
1638 const CollapsingInfo &collapsingInfo) {
1639 unsigned counter = 0;
1640 SmallVector<ReassociationIndices> operandReassociation;
1641 auto origOpToCollapsedOpMapping =
1642 collapsingInfo.getOrigOpToCollapsedOpMapping();
1643 auto collapsedOpToOrigOpMapping =
1644 collapsingInfo.getCollapsedOpToOrigOpMapping();
1645 while (counter < indexingMap.getNumResults()) {
1646 unsigned dim =
1647 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1648 // This is the start of a group of collapsed iteration dimensions that is
1649 // guaranteed to be preserved in the indexing map. The number of folded
1650 // dims is obtained from the collapsed op to original op mapping.
1651 unsigned numFoldedDims =
1652 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1653 .size();
1654 if (origOpToCollapsedOpMapping[dim].second == 0) {
1655 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1656 operandReassociation.emplace_back(range.begin(), range.end());
1657 }
1658 counter += numFoldedDims;
1659 }
1660 return operandReassociation;
1661}
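// For the same assumed collapsing ({0, 1} and {3, 4} folded), an operand with
// indexing map (d0, d1, d2, d3, d4) -> (d3, d4, d0, d1, d2) gets the operand
// reassociation [[0, 1], [2, 3], [4]]: results 0-1 cover the folded (d3, d4)
// group, results 2-3 cover (d0, d1), and d2 stays by itself.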
1662
1663/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1664static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1665 OpOperand *opOperand,
1666 const CollapsingInfo &collapsingInfo,
1667 OpBuilder &builder) {
1668 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1669 SmallVector<ReassociationIndices> operandReassociation =
1670 getOperandReassociation(indexingMap, collapsingInfo);
1671
1672 // If the number of entries in the reassociation for the operand is the
1673 // same as the number of results of the indexing map, then there is nothing
1674 // to do for this operand.
1675 Value operand = opOperand->get();
1676 if (operandReassociation.size() == indexingMap.getNumResults())
1677 return operand;
1678
1679 // Insert a reshape to collapse the dimensions.
1680 if (isa<MemRefType>(operand.getType())) {
1681 return memref::CollapseShapeOp::create(builder, loc, operand,
1682 operandReassociation)
1683 .getResult();
1684 }
1685 return tensor::CollapseShapeOp::create(builder, loc, operand,
1686 operandReassociation)
1687 .getResult();
1688}
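// Sketch (types assumed): for a tensor operand of type tensor<?x4x8xf32> with
// operand reassociation [[0, 1], [2]], this emits
//
// ```mlir
// %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
//     : tensor<?x4x8xf32> into tensor<?x8xf32>
// ```
//
// while a memref operand gets a memref.collapse_shape instead.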
1689
1690/// Modify the `linalg.index` operations in the original generic op to compute
1691/// their values in terms of the collapsed loop induction variables.
1692static void generateCollapsedIndexingRegion(
1693 Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1694 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1695 OpBuilder::InsertionGuard g(rewriter);
1696 rewriter.setInsertionPointToStart(block);
1697
1698 // Collect all the original index ops.
1699 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1700
1701 // For each folded dimension list, resolve the original induction variable
1702 // values in terms of the folded dimension's induction variable:
1703 //   i_folded = (i0 * d1 + i1) * d2 + i2
1704 // can be inverted to
1705 //   i2 = i_folded % d2
1706 //   i1 = (i_folded / d2) % d1
1707 //   i0 = i_folded / (d1 * d2)
1708 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1709 for (auto foldedDims :
1710 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1711 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1712 Value newIndexVal =
1713 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1714 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1715 Value loopDim =
1716 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1717 indexReplacementVals[dim] =
1718 rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1719 newIndexVal =
1720 rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1721 }
1722 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1723 }
1724
1725 for (auto indexOp : indexOps) {
1726 auto dim = indexOp.getDim();
1727 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1728 }
1729}
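// A hedged sketch of the rewrite for one folded pair (d0, d1), where %c_d1
// (an illustrative name) holds the size of the inner loop d1: each original
// linalg.index is delinearized from the single collapsed index.
//
// ```mlir
// %i  = linalg.index 0 : index        // collapsed induction variable
// %i1 = arith.remsi %i, %c_d1 : index // replaces linalg.index 1 (inner dim)
// %i0 = arith.divsi %i, %c_d1 : index // replaces linalg.index 0 (outer dim)
// ```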
1730
1731static void collapseOperandsAndResults(LinalgOp op,
1732 const CollapsingInfo &collapsingInfo,
1733 RewriterBase &rewriter,
1734 SmallVectorImpl<Value> &inputOperands,
1735 SmallVectorImpl<Value> &outputOperands,
1736 SmallVectorImpl<Type> &resultTypes) {
1737 Location loc = op->getLoc();
1738 inputOperands =
1739 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1740 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1741 rewriter);
1742 });
1743
1744 // Get the output operands and result types.
1745 resultTypes.reserve(op.getNumDpsInits());
1746 outputOperands.reserve(op.getNumDpsInits());
1747 for (OpOperand &output : op.getDpsInitsMutable()) {
1748 Value newOutput =
1749 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1750 outputOperands.push_back(newOutput);
1751 // If the op has "buffer semantics", then the init operands are ranked
1752 // memrefs and the op has no results.
1753 if (!op.hasPureBufferSemantics())
1754 resultTypes.push_back(newOutput.getType());
1755 }
1756}
1757
1758/// Clone a `LinalgOp` to a collapsed version of the same name.
1759template <typename OpTy>
1760static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1761 const CollapsingInfo &collapsingInfo) {
1762 return nullptr;
1763}
1764
1765/// Collapse any `LinalgOp` that does not require any specialization such as
1766/// indexing_maps, iterator_types, etc.
1767template <>
1768LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1769 const CollapsingInfo &collapsingInfo) {
1770 SmallVector<Value> inputOperands, outputOperands;
1771 SmallVector<Type> resultTypes;
1772 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1773 outputOperands, resultTypes);
1774
1775 return clone(
1776 rewriter, origOp, resultTypes,
1777 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1778}
1779
1780/// Collapse a `GenericOp`
1781template <>
1782GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1783 GenericOp origOp,
1784 const CollapsingInfo &collapsingInfo) {
1785 SmallVector<Value> inputOperands, outputOperands;
1786 SmallVector<Type> resultTypes;
1787 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1788 outputOperands, resultTypes);
1789 SmallVector<AffineMap> indexingMaps(
1790 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1791 return getCollapsedOpIndexingMap(map, collapsingInfo);
1792 }));
1793
1794 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1795 origOp.getIteratorTypesArray(), collapsingInfo));
1796
1797 GenericOp collapsedOp = linalg::GenericOp::create(
1798 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1799 indexingMaps, iteratorTypes,
1800 [](OpBuilder &builder, Location loc, ValueRange args) {});
1801 Block *origOpBlock = &origOp->getRegion(0).front();
1802 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1803 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1804 collapsedOpBlock->getArguments());
1805 return collapsedOp;
1806}
1807
1808static LinalgOp createCollapsedOp(LinalgOp op,
1809 const CollapsingInfo &collapsingInfo,
1810 RewriterBase &rewriter) {
1811 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1812 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1813 }
1814 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1815}
1816
1817/// Implementation of fusion with reshape operation by collapsing dimensions.
1818FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1819 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1820 RewriterBase &rewriter) {
1821 // Bail on trivial no-op cases.
1822 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1823 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1824 return foldedDims.size() <= 1;
1825 }))
1826 return failure();
1827
1828 CollapsingInfo collapsingInfo;
1829 if (failed(
1830 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1831 return rewriter.notifyMatchFailure(
1832 op, "illegal to collapse specified dimensions");
1833 }
1834
1835 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1836 if (hasPureBufferSemantics &&
1837 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1838 MemRefType memRefToCollapse =
1839 dyn_cast<MemRefType>(opOperand.get().getType());
1840 if (!memRefToCollapse)
1841 return true;
1842
1843 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1844 SmallVector<ReassociationIndices> operandReassociation =
1845 getOperandReassociation(indexingMap, collapsingInfo);
1846 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1847 memRefToCollapse, operandReassociation);
1848 }))
1849 return rewriter.notifyMatchFailure(op,
1850 "memref is not guaranteed collapsible");
1851
1852 // Bail on non-canonical ranges.
1853 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1854 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1855 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1856 return cast<IntegerAttr>(attr).getInt() == value;
1857 llvm::APInt actual;
1858 return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1859 actual.getSExtValue() == value;
1860 };
1861 if (!llvm::all_of(loopRanges, [&](Range range) {
1862 return opFoldIsConstantValue(range.offset, 0) &&
1863 opFoldIsConstantValue(range.stride, 1);
1864 })) {
1865 return rewriter.notifyMatchFailure(
1866 op, "expected all loop ranges to have zero start and unit stride");
1867 }
1868
1869 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1870
1871 Location loc = op->getLoc();
1872 SmallVector<OpFoldResult> loopBound =
1873 llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1874
1875 if (collapsedOp.hasIndexSemantics()) {
1876 // Collect the loop range of the generic op.
1877 OpBuilder::InsertionGuard g(rewriter);
1878 rewriter.setInsertionPoint(collapsedOp);
1879 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1880 collapsingInfo, loopBound, rewriter);
1881 }
1882
1883 // Insert expanding reshape for the result to get back the original result
1884 // type.
1885 SmallVector<Value> results;
1886 for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1887 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1888 auto originalResultType =
1889 cast<ShapedType>(originalResult.value().getType());
1890 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1891 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1892 AffineMap indexingMap =
1893 op.getIndexingMapMatchingResult(originalResult.value());
1894 SmallVector<ReassociationIndices> reassociation =
1895 getOperandReassociation(indexingMap, collapsingInfo);
1896 assert(
1897 indexingMap.isProjectedPermutation() &&
1898 "Expected indexing map to be a projected permutation for collapsing");
1899 SmallVector<OpFoldResult> resultShape =
1900 applyPermutationMap(indexingMap, ArrayRef(loopBound));
1901 Value result;
1902 if (isa<MemRefType>(collapsedOpResult.getType())) {
1903 result = memref::ExpandShapeOp::create(
1904 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1905 resultShape);
1906 } else {
1907 result = tensor::ExpandShapeOp::create(
1908 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1909 resultShape);
1910 }
1911 results.push_back(result);
1912 } else {
1913 results.push_back(collapsedOpResult);
1914 }
1915 }
1916 return CollapseResult{results, collapsedOp};
1917}
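// Illustrative use (op and rewriter assumed to exist):
//
//   FailureOr<CollapseResult> maybeCollapsed =
//       collapseOpIterationDims(linalgOp, {{0, 1}}, rewriter);
//   if (succeeded(maybeCollapsed))
//     rewriter.replaceOp(linalgOp, maybeCollapsed->results);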
1918
1919namespace {
1920
1921/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1922/// contracting dimensions of the loop.
1923class FoldWithProducerReshapeOpByCollapsing
1924 : public OpRewritePattern<GenericOp> {
1925public:
1926 // TODO : support fusion with all linalg ops, not just generic.
1927 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1928 ControlFusionFn foldReshapes,
1929 PatternBenefit benefit = 1)
1930 : OpRewritePattern<GenericOp>(context, benefit),
1931 controlFoldingReshapes(std::move(foldReshapes)) {}
1932
1933 LogicalResult matchAndRewrite(GenericOp genericOp,
1934 PatternRewriter &rewriter) const override {
1935 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1936 tensor::ExpandShapeOp reshapeOp =
1937 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1938 if (!reshapeOp)
1939 continue;
1940
1941 SmallVector<ReassociationIndices> collapsableIterationDims =
1942 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1943 reshapeOp.getReassociationIndices());
1944 if (collapsableIterationDims.empty() ||
1945 !controlFoldingReshapes(&opOperand)) {
1946 continue;
1947 }
1948
1949 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1950 genericOp, collapsableIterationDims, rewriter);
1951 if (!collapseResult) {
1952 return rewriter.notifyMatchFailure(
1953 genericOp, "failed to do the fusion by collapsing transformation");
1954 }
1955
1956 rewriter.replaceOp(genericOp, collapseResult->results);
1957 return success();
1958 }
1959 return failure();
1960 }
1961
1962private:
1963 ControlFusionFn controlFoldingReshapes;
1964};
1965
1966/// Pattern to fold a tensor.collapse_shape op with its producer generic op
1967/// by collapsing the corresponding dimensions of the producer's iteration
1968/// space.
1968struct FoldReshapeWithGenericOpByCollapsing
1969 : public OpRewritePattern<tensor::CollapseShapeOp> {
1970
1971 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1972 ControlFusionFn foldReshapes,
1973 PatternBenefit benefit = 1)
1974 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1975 controlFoldingReshapes(std::move(foldReshapes)) {}
1976
1977 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1978 PatternRewriter &rewriter) const override {
1979 // Fold only if all constraints of fusing with reshape by collapsing are
1980 // met.
1981 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1982 if (!producerResult) {
1983 return rewriter.notifyMatchFailure(reshapeOp,
1984 "source not produced by an operation");
1985 }
1986
1987 // TODO : support fusion with all linalg producers, not just generic.
1988 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1989 if (!producer) {
1990 return rewriter.notifyMatchFailure(reshapeOp,
1991 "producer not a generic op");
1992 }
1993
1994 SmallVector<ReassociationIndices> collapsableIterationDims =
1995 getCollapsableIterationSpaceDims(
1996 producer,
1997 producer.getDpsInitOperand(producerResult.getResultNumber()),
1998 reshapeOp.getReassociationIndices());
1999 if (collapsableIterationDims.empty()) {
2000 return rewriter.notifyMatchFailure(
2001 reshapeOp, "failed preconditions of fusion with producer generic op");
2002 }
2003
2004 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2005 return rewriter.notifyMatchFailure(reshapeOp,
2006 "fusion blocked by control function");
2007 }
2008
2009 // Set the insertion point after `producer` because there could be uses
2010 // of `producer` between it and the `tensor.collapse_shape` op.
2011 rewriter.setInsertionPointAfter(producer);
2012 std::optional<CollapseResult> collapseResult =
2013 collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
2014 if (!collapseResult) {
2015 return rewriter.notifyMatchFailure(
2016 producer, "failed to do the fusion by collapsing transformation");
2017 }
2018
2019 rewriter.replaceOp(producer, collapseResult->results);
2020 return success();
2021 }
2022
2023private:
2024 ControlFusionFn controlFoldingReshapes;
2025};
2026
2027/// Computes the collapsed padding information for the given pad operation based
2028/// on the provided collapsed shape and reassociation indices. Returns a
2029/// PadDimInfo containing the low and high padding amounts and the collapsed
2030/// shape for each dimension, or failure if the collapse is not possible.
2031static FailureOr<PadDimInfo>
2032computeCollapsedPadding(tensor::PadOp padOp,
2033 ArrayRef<ReassociationIndices> reassociations,
2034 PatternRewriter &rewriter) {
2035 // If the padding value depends on the index values of the pad operation,
2036 // then it may not be valid to collapse the dimensions, since it will change
2037 // the index values on which the padding value depends. This is not currently
2038 // supported by the pad collapsing patterns, but it could be implemented
2039 // similarly to the collapsing of linalg.generic ops with linalg.index ops in
2040 // the body, as is done in `generateCollapsedIndexingRegion`.
2041 if (!padOp.getConstantPaddingValue())
2042 return failure();
2043
2044 // Collapsed dimensions cannot have padding because this can produce strided
2045 // padding that isn't representable by a tensor.pad op. There are some special
2046 // cases where it is possible (like collapsing unit dims), but supporting
2047 // these cases is NYI, so disallow it for now.
2048 ArrayRef<int64_t> low = padOp.getStaticLow();
2049 ArrayRef<int64_t> high = padOp.getStaticHigh();
2050 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2051 for (int64_t dim : reInd) {
2052 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2053 return failure();
2054 }
2055 }
2056
2057 // Initialize padding values for collapsed tensors with zeros
2058 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2059 PadDimInfo padDimInfo;
2060 padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2061 padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2062
2063 // Update padding for dimensions that are not being collapsed, and compute
2064 // the collapsed padded shape.
2065 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2066 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2067 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2068 if (reInd.size() == 1) {
2069 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2070 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2071 }
2072 SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
2073 for (int64_t dim : reInd) {
2074 collapsedSize =
2075 collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
2076 }
2077 padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
2078 }
2079
2080 return padDimInfo;
2081}
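// Worked example (shapes assumed): collapsing a pad of tensor<?x4x5xf32> with
// reassociation [[0], [1, 2]] requires dims 1 and 2 to carry no padding;
// padding on the singleton dim 0 (say low = 1, high = 2) carries over
// unchanged, giving lowPad = [1, 0], highPad = [2, 0], paddedShape = [?, 20].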
2082
2083class FoldPadWithProducerReshapeOpByCollapsing
2084 : public OpRewritePattern<tensor::PadOp> {
2085public:
2086 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2087 ControlFusionFn foldReshapes,
2088 PatternBenefit benefit = 1)
2089 : OpRewritePattern<tensor::PadOp>(context, benefit),
2090 controlFoldingReshapes(std::move(foldReshapes)) {}
2091
2092 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2093 PatternRewriter &rewriter) const override {
2094 tensor::ExpandShapeOp reshapeOp =
2095 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2096 if (!reshapeOp)
2097 return failure();
2098
2099 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2100 return rewriter.notifyMatchFailure(padOp,
2101 "fusion blocked by control function");
2102 }
2103
2104 SmallVector<ReassociationIndices> reassociations =
2105 reshapeOp.getReassociationIndices();
2106 FailureOr<PadDimInfo> maybeCollapsedPadding =
2107 computeCollapsedPadding(padOp, reassociations, rewriter);
2108 if (failed(maybeCollapsedPadding))
2109 return failure();
2110 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2111
2112 SmallVector<OpFoldResult> expandedPaddedSizes =
2113 reshapeOp.getMixedOutputShape();
2114 AffineExpr d0, d1, d2;
2115 bindDims(rewriter.getContext(), d0, d1, d2);
2116 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
2117 Location loc = reshapeOp->getLoc();
2118 for (auto [reInd, l, h] :
2119 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2120 collapsedPadding.highPad)) {
2121 if (reInd.size() == 1) {
2122 expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
2123 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2124 }
2125 }
2126
2127 RankedTensorType collapsedPaddedType =
2128 padOp.getType().clone(collapsedPadding.paddedShape);
2129 auto newPadOp = tensor::PadOp::create(
2130 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2131 collapsedPadding.lowPad, collapsedPadding.highPad,
2132 padOp.getConstantPaddingValue(), padOp.getNofold());
2133
2134 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
2135 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2136 expandedPaddedSizes);
2137
2138 return success();
2139 }
2140
2141private:
2142 ControlFusionFn controlFoldingReshapes;
2143};
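// A hedged before/after sketch of this pattern (shapes and padding assumed):
//
// ```mlir
// %expand = tensor.expand_shape %src [[0], [1, 2]]
//     : tensor<?x20xf32> into tensor<?x4x5xf32>
// %pad = tensor.pad %expand low[1, 0, 0] high[2, 0, 0] { ... }
//     : tensor<?x4x5xf32> to tensor<?x4x5xf32>
// ```
//
// becomes
//
// ```mlir
// %pad = tensor.pad %src low[1, 0] high[2, 0] { ... }
//     : tensor<?x20xf32> to tensor<?x20xf32>
// %expand = tensor.expand_shape %pad [[0], [1, 2]]
//     : tensor<?x20xf32> into tensor<?x4x5xf32>
// ```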
2144
2145class FoldReshapeWithProducerPadOpByCollapsing
2146 : public OpRewritePattern<tensor::CollapseShapeOp> {
2147public:
2148 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2149 ControlFusionFn foldReshapes,
2150 PatternBenefit benefit = 1)
2151 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2152 controlFoldingReshapes(std::move(foldReshapes)) {}
2153
2154 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2155 PatternRewriter &rewriter) const override {
2156 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2157 if (!padOp)
2158 return failure();
2159
2160 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2161 return rewriter.notifyMatchFailure(padOp,
2162 "fusion blocked by control function");
2163 }
2164
2165 SmallVector<ReassociationIndices> reassociations =
2166 reshapeOp.getReassociationIndices();
2167 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2168 FailureOr<PadDimInfo> maybeCollapsedPadding =
2169 computeCollapsedPadding(padOp, reassociations, rewriter);
2170 if (failed(maybeCollapsedPadding))
2171 return failure();
2172 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2173
2174 Location loc = reshapeOp->getLoc();
2175 auto newCollapseOp = tensor::CollapseShapeOp::create(
2176 rewriter, loc, padOp.getSource(), reassociations);
2177
2178 auto newPadOp = tensor::PadOp::create(
2179 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2180 collapsedPadding.lowPad, collapsedPadding.highPad,
2181 padOp.getConstantPaddingValue(), padOp.getNofold());
2182
2183 rewriter.replaceOp(reshapeOp, newPadOp.getResult());
2184 return success();
2185 }
2186
2187private:
2188 ControlFusionFn controlFoldingReshapes;
2189};
2190
2191/// Pattern to collapse dimensions.
2192template <typename LinalgType>
2193class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2194public:
2195 CollapseLinalgDimensions(MLIRContext *context,
2196 GetCollapsableDimensionsFn collapseDimensions,
2197 PatternBenefit benefit = 1)
2198 : OpRewritePattern<LinalgType>(context, benefit),
2199 controlCollapseDimension(std::move(collapseDimensions)) {}
2200
2201 LogicalResult matchAndRewrite(LinalgType op,
2202 PatternRewriter &rewriter) const override {
2203 SmallVector<ReassociationIndices> collapsableIterationDims =
2204 controlCollapseDimension(op);
2205 if (collapsableIterationDims.empty())
2206 return failure();
2207
2208 // Check if the specified list of dimensions to collapse is a valid list.
2209 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2210 collapsableIterationDims)) {
2211 return rewriter.notifyMatchFailure(
2212 op, "specified dimensions cannot be collapsed");
2213 }
2214
2215 std::optional<CollapseResult> collapseResult =
2216 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2217 if (!collapseResult) {
2218 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2219 }
2220 rewriter.replaceOp(op, collapseResult->results);
2221 return success();
2222 }
2223
2224private:
2225 GetCollapsableDimensionsFn controlCollapseDimension;
2226};
2227
2228} // namespace
2229
2230//===---------------------------------------------------------------------===//
2231// Methods and patterns that fuse constants with linalg.generic operations.
2232//===---------------------------------------------------------------------===//
2233
2234namespace {
2235/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2236/// handle cases where the constant is not single-valued.
2237class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2238public:
2239 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2240 : OpRewritePattern<GenericOp>(context, benefit) {}
2241
2242 LogicalResult matchAndRewrite(GenericOp genericOp,
2243 PatternRewriter &rewriter) const override {
2244 if (!genericOp.hasPureTensorSemantics())
2245 return failure();
2246 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2247 Operation *def = opOperand->get().getDefiningOp();
2248 TypedAttr constantAttr;
2249 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2250 {
2251 DenseElementsAttr splatAttr;
2252 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2253 splatAttr.isSplat() &&
2254 splatAttr.getType().getElementType().isIntOrFloat()) {
2255 constantAttr = splatAttr.getSplatValue<TypedAttr>();
2256 return true;
2257 }
2258 }
2259 {
2260 IntegerAttr intAttr;
2261 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2262 constantAttr = intAttr;
2263 return true;
2264 }
2265 }
2266 {
2267 FloatAttr floatAttr;
2268 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2269 constantAttr = floatAttr;
2270 return true;
2271 }
2272 }
2273 return false;
2274 };
2275
2276 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2277 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2278 continue;
2279
2280 // The operands and the indexing_maps of the fused operation are the same
2281 // as the operands and indexing_maps of the generic operation, with the
2282 // value at the constant index dropped.
2283 SmallVector<AffineMap> fusedIndexMaps;
2284 SmallVector<Value> fusedOperands;
2285 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2286 fusedIndexMaps.reserve(genericOp->getNumOperands());
2287 fusedOperands.reserve(genericOp.getNumDpsInputs());
2288 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2289 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2290 if (inputOperand == opOperand)
2291 continue;
2292 Value inputValue = inputOperand->get();
2293 fusedIndexMaps.push_back(
2294 genericOp.getMatchingIndexingMap(inputOperand));
2295 fusedOperands.push_back(inputValue);
2296 fusedLocs.push_back(inputValue.getLoc());
2297 }
2298 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2299 fusedIndexMaps.push_back(
2300 genericOp.getMatchingIndexingMap(&outputOperand));
2301
2302 // Check if the operation's shapes-to-loops map is computable.
2303 if (!inversePermutation(
2304 concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2305 return rewriter.notifyMatchFailure(
2306 genericOp, "fused op loop bound computation failed");
2307 }
2308
2309 // Create a constant scalar value from the splat constant.
2310 Value scalarConstant =
2311 arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
2312
2313 SmallVector<Value> outputOperands = genericOp.getOutputs();
2314 auto fusedOp =
2315 GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
2316 genericOp->getResultTypes(),
2317 /*inputs=*/fusedOperands,
2318 /*outputs=*/outputOperands,
2319 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2320 genericOp.getIteratorTypes(),
2321 /*doc=*/nullptr,
2322 /*library_call=*/nullptr);
2323
2324 // Map the block argument corresponding to the replaced argument with the
2325 // scalar constant.
2326 Region &region = genericOp->getRegion(0);
2327 Block &entryBlock = *region.begin();
2328 IRMapping mapping;
2329 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2330 scalarConstant);
2331 Region &fusedRegion = fusedOp->getRegion(0);
2332 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2333 mapping);
2334 rewriter.replaceOp(genericOp, fusedOp->getResults());
2335 return success();
2336 }
2337 return failure();
2338 }
2339};
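// Illustrative before/after for this pattern (IR assumed):
//
// ```mlir
// %cst = arith.constant dense<4.0> : tensor<8xf32>
// %0 = linalg.generic { ... }
//     ins(%arg0, %cst : tensor<8xf32>, tensor<8xf32>)
//     outs(%init : tensor<8xf32>) { ... }
// ```
//
// is rewritten into a generic op with one fewer input whose payload uses an
// `arith.constant 4.0 : f32` scalar directly.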
2340
2341} // namespace
2342
2343//===---------------------------------------------------------------------===//
2344// Miscellaneous patterns that help fusion.
2345//===---------------------------------------------------------------------===//
2346
2347namespace {
2348/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2349/// value of the `outs` operand is not used within the op. This is only
2350/// implemented for `linalg.generic` operations for now, but should hold for all
2351/// linalg structured ops.
2352struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2353 using OpRewritePattern<GenericOp>::OpRewritePattern;
2354
2355 LogicalResult matchAndRewrite(GenericOp op,
2356 PatternRewriter &rewriter) const override {
2357 rewriter.startOpModification(op);
2358 bool modifiedOutput = false;
2359 Location loc = op.getLoc();
2360 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2361 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2362 Value operandVal = opOperand.get();
2363 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2364 if (!operandType)
2365 continue;
2366
2367 // If outs is sparse, leave it to the sparsifier.
2368 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2369 continue;
2370
2371 // If outs is already an `empty` operation, nothing to do.
2372 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2373 if (definingOp)
2374 continue;
2375 modifiedOutput = true;
2376 SmallVector<OpFoldResult> mixedSizes =
2377 tensor::getMixedSizes(rewriter, loc, operandVal);
2378 Value emptyTensor = tensor::EmptyOp::create(
2379 rewriter, loc, mixedSizes, operandType.getElementType());
2380 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2381 }
2382 }
2383 if (!modifiedOutput) {
2384 rewriter.cancelOpModification(op);
2385 return failure();
2386 }
2387 rewriter.finalizeOpModification(op);
2388 return success();
2389 }
2390};
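// Illustrative effect (IR assumed): when the payload never reads the block
// argument tied to an init operand, the pattern rewrites
//
//   outs(%arg1 : tensor<8xf32>)
//
// into
//
//   %empty = tensor.empty() : tensor<8xf32>
//   ... outs(%empty : tensor<8xf32>)
//
// which removes a use-def dependency that would otherwise inhibit fusion.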
2391
2392/// Fold linalg.fill into linalg.generic
2393struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2394 using OpRewritePattern<GenericOp>::OpRewritePattern;
2395
2396 LogicalResult matchAndRewrite(GenericOp genericOp,
2397 PatternRewriter &rewriter) const override {
2398 if (!genericOp.hasPureTensorSemantics())
2399 return failure();
2400 bool fillFound = false;
2401 Block &payload = genericOp.getRegion().front();
2402 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2403 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2404 continue;
2405 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2406 if (!fillOp)
2407 continue;
2408 fillFound = true;
2409 Value fillVal = fillOp.value();
2410 auto resultType =
2411 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2412 Value convertedVal =
2413 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2414 /*isUnsignedCast =*/false);
2415 rewriter.replaceAllUsesWith(
2416 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2417 }
2418 return success(fillFound);
2419 }
2420};
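// Illustrative before/after (IR assumed):
//
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8xf32>)
//               -> tensor<8xf32>
//   %res = linalg.generic ... ins(%fill : tensor<8xf32>) ...
//
// After the fold, the payload block argument tied to %fill is replaced by
// %cst (converted to the element type if needed); the now-dead %fill input is
// left for the dead-operand cleanup patterns to erase.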
2421} // namespace
2422
2423void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2424 RewritePatternSet &patterns,
2425 const ControlFusionFn &controlFoldingReshapes) {
2426 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2427 controlFoldingReshapes);
2428 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2429 controlFoldingReshapes);
2430 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
2431 controlFoldingReshapes);
2432 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2433 controlFoldingReshapes);
2434}
2435
2436void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2437 RewritePatternSet &patterns,
2438 const ControlFusionFn &controlFoldingReshapes) {
2439 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2440 controlFoldingReshapes);
2441 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2442 patterns.getContext(), controlFoldingReshapes);
2443 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2444 patterns.getContext(), controlFoldingReshapes);
2445 patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2446 controlFoldingReshapes);
2447}
2448
2449void mlir::linalg::populateElementwiseOpsFusionPatterns(
2450 RewritePatternSet &patterns,
2451 const ControlFusionFn &controlElementwiseOpsFusion) {
2452 auto *context = patterns.getContext();
2453 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2454 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2455 RemoveOutsDependency>(context);
2456 // Add the patterns that clean up dead operands and results.
2457 populateEraseUnusedOperandsAndResultsPatterns(patterns);
2458}
2459
2460void mlir::linalg::populateCollapseDimensions(
2461 RewritePatternSet &patterns,
2462 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2463 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2464 CollapseLinalgDimensions<linalg::CopyOp>>(
2465 patterns.getContext(), controlCollapseDimensions);
2466}
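// Illustrative use (callback assumed): always request collapsing of loops 0
// and 1 of every candidate op.
//
//   GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
//     return SmallVector<ReassociationIndices>{{0, 1}};
//   };
//   populateCollapseDimensions(patterns, control);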
2467
2468//===---------------------------------------------------------------------===//
2469// Passes
2470//===---------------------------------------------------------------------===//
2471
2472namespace {
2473
2474/// Pass that fuses generic ops on tensors. Used only for testing.
2475// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2476// patterns added here heavily depends on the cost function used. Having an
2477// opinionated pass of this form is not recommended. Deprecate this pass in
2478// favor of test passes that check the functionality of each of the patterns
2479// added here individually.
2480struct LinalgElementwiseOpFusionPass
2481 : public impl::LinalgElementwiseOpFusionPassBase<
2482 LinalgElementwiseOpFusionPass> {
2483 using impl::LinalgElementwiseOpFusionPassBase<
2484 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2485 void runOnOperation() override {
2486 Operation *op = getOperation();
2487 MLIRContext *context = op->getContext();
2488 RewritePatternSet patterns(context);
2489
2490 // Add folding with reshape by expansion patterns.
2491 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2492 Operation *producer = fusedOperand->get().getDefiningOp();
2493 return producer && producer->hasOneUse();
2494 };
2495
2496 // Add elementwise op fusion patterns.
2497 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2498 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2500 
2501 // General canonicalization patterns.
2502 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2503 GenericOp::getCanonicalizationPatterns(patterns, context);
2504 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2505 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2506 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2507 patterns);
2508
2509 // Add constant folding patterns.
2510 populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2511 
2512 // Use TopDownTraversal for compile time reasons.
2513 (void)applyPatternsGreedily(op, std::move(patterns),
2514 GreedyRewriteConfig().setUseTopDownTraversal());
2515 }
2516};
2517
2518} // namespace
return success()
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
lhs
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
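For illustration, a sketch only (b and loc stand in for an OpBuilder and Location available at the call site): composing and folding d0 + 1 over a constant index yields an attribute instead of new IR.
AffineExpr d0;
bindDims(b.getContext(), d0);
AffineMap addOne = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 1);
OpFoldResult sum = affine::makeComposedFoldedAffineApply(
    b, loc, addOne, {b.getIndexAttr(41)});
// 'sum' folds to the index attribute 42; no affine.apply op is created.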
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges of the maps.
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:230
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions is contiguous in the range of the specified indexing map.
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by collapsing the dimensions of the generic op.
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
This transformation is intended to be used with a top-down traversal (from producer to consumer).
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
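A hedged sketch of how a rewrite pattern might drive these entry points (consumer and rewriter are assumed to be an in-scope GenericOp and PatternRewriter; the upstream pattern's replacement handling is more involved):
// Try to fuse some producer into 'consumer'; old results are remapped
// through the value map carried in the fusion result.
for (OpOperand &opOperand : consumer->getOpOperands()) {
  if (!linalg::areElementwiseOpsFusable(&opOperand))
    continue;
  FailureOr<linalg::ElementwiseOpFusionResult> fused =
      linalg::fuseElementwiseOps(rewriter, &opOperand);
  if (failed(fused))
    continue;
  for (auto [origVal, replacement] : fused->replacements)
    rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
      // Do not replace uses inside the fused op itself.
      return use.getOwner() != fused->fusedOp;
    });
  break;
}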
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:234
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with a producer/consumer relationship through fusedOperand can be fused using elementwise op fusion.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) generic operation by expanding the dimensionality of the loop in the generic op.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of a linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
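Taken together, a pass body might wire a ControlFusionFn and these patterns up as follows (a sketch under assumed names ctx, op, and signalPassFailure):
RewritePatternSet patterns(ctx);
// Only fuse when the producer's result has no other consumers.
linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
  Operation *producer = fusedOperand->get().getDefiningOp();
  return producer && producer->hasOneUse();
};
linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
if (failed(applyPatternsGreedily(op->getRegion(0), std::move(patterns))))
  signalPassFailure();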
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:238
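For example (a sketch; smallVal is an assumed i8-typed Value, with b and loc as before):
// Sign-extend an i8 scalar to i32.
Value widened = convertScalarToDtype(b, loc, smallVal, b.getI32Type(),
                                     /*isUnsignedCast=*/false);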
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
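A small sketch (ctx is an assumed MLIRContext*) of inverting a permutation map:
// The permutation (d0, d1, d2) -> (d2, d0, d1) inverts to
// (d0, d1, d2) -> (d1, d2, d0).
SmallVector<unsigned> p = {2, 0, 1};
AffineMap perm = AffineMap::getPermutationMap(p, ctx);
AffineMap inv = inversePermutation(perm);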
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
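A sketch of bindDims and bindSymbols together (ctx is an assumed MLIRContext*):
AffineExpr d0, d1, s0;
bindDims(ctx, d0, d1);
bindSymbols(ctx, s0);
// Builds the map (d0, d1)[s0] -> (d0 + s0, d1).
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, {d0 + s0, d1}, ctx);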
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
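For instance (reusing the permutation map perm built in the inversePermutation sketch above):
// Applying (d0, d1, d2) -> (d2, d0, d1) to {10, 20, 30} picks the
// source element named by each result dimension.
SmallVector<int64_t> permuted =
    applyPermutationMap(perm, ArrayRef<int64_t>{10, 20, 30});
// permuted == {30, 10, 20}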
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
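A sketch combining this with getMixedSizes above (tensorValue is an assumed tensor-typed Value, with b and loc as before):
SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, tensorValue);
auto [staticSizes, dynamicSizes] = decomposeMixedValues(sizes);
// staticSizes carries constants (ShapedType::kDynamic as a sentinel for
// dynamic dims); dynamicSizes carries the matching SSA values.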
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation vector.
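A one-line sketch:
// inverse[perm[i]] == i, so the inverse of {2, 0, 1} is {1, 2, 0}.
SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});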
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedOperand.
Definition Transforms.h:656