//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the block argument corresponding to the init is used in the region. This is
/// a fix-up for when unit reduction dimensions are all folded away. In this
/// context, the op becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = tensor::EmptyOp::create(
          rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
          elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = GenericOp::create(
        rewriter, loc, genericOp.getResultTypes(), newInputOperands,
        newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());

    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
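/// For example, if loop dimension 1 is dropped as a unit dim, `linalg.index 1`
/// is rewritten to `arith.constant 0 : index`, and `linalg.index 2` is
/// renumbered to `linalg.index 1` to account for the removed loop.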
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

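/// Expand a rank-reduced `result` back to the shape of `origDest`, either by
/// inserting it into `origDest` with a `tensor.insert_slice` or through a
/// `tensor.expand_shape`, depending on the rank-reduction strategy in
/// `control`. Fails for tensors with an encoding.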
FailureOr<Value>
ControlDropUnitDims::expandValue(RewriterBase &rewriter, Location loc,
                                 Value result, Value origDest,
                                 ArrayRef<ReassociationIndices> reassociation,
                                 const ControlDropUnitDims &control) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (origResultType.getEncoding() != nullptr) {
    // Do not expand tensors with encoding.
    return failure();
  }
  if (control.rankReductionStrategy ==
      RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(control.rankReductionStrategy ==
             RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
                                       reassociation)
      .getResult();
}

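/// Collapse `operand` to `targetShape`, using either a rank-reducing
/// `memref.subview`/`tensor.extract_slice` or a collapse_shape op, depending
/// on the rank-reduction strategy in `control`. Fails for memrefs with a
/// non-identity layout and for tensors with an encoding.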
FailureOr<Value>
ControlDropUnitDims::collapseValue(RewriterBase &rewriter, Location loc,
                                   Value operand, ArrayRef<int64_t> targetShape,
                                   ArrayRef<ReassociationIndices> reassociation,
                                   const ControlDropUnitDims &control) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (!memrefType.getLayout().isIdentity()) {
      // Do not collapse memrefs with a non-identity layout.
      return failure();
    }
    if (control.rankReductionStrategy ==
        RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(control.rankReductionStrategy ==
               RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (tensorType.getEncoding() != nullptr) {
      // Do not collapse tensors with an encoding.
      return failure();
    }
    if (control.rankReductionStrategy ==
        RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(control.rankReductionStrategy ==
               RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  llvm_unreachable("unsupported operand type");
}

/// Compute the modified metadata for an operand of an operation whose unit
/// dims are being dropped. Returns the new indexing map to use, the shape of
/// the operand in the replacement op, and the `reassociation` to use to go
/// from the original operand shape to the modified operand shape.
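/// For example, an operand of type `tensor<1x4x1xf32>` indexed through
/// `affine_map<(d0, d1, d2) -> (d0, d1, d2)>`, where `d0` and `d2` are dropped
/// unit dims, yields a target shape of `[4]`, the indexing map
/// `affine_map<(d0) -> (d0)>`, and the reassociation `[[0, 1, 2]]`.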
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};

static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      // Use the new expression after replacing dimensions that will be dropped
      // here to handle cases where an affine expression with multiple
      // dimensions (e.g., `d0 + d2`) can be simplified to 0 after dropping all
      // dimensions used in the expression (`d0` and `d2` in this example).
      AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
                     const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                     const ControlDropUnitDims &options) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!dpsOp) {
    return rewriter.notifyMatchFailure(
        op, "op should implement DestinationStyleOpInterface");
  }

  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  //    end up being unit-trip count if they are used to index into a unit-dim
  //    tensor/memref.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(op,
                                       "invalid indexing maps for operation");
  }

  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : op->getOpOperands())
    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        op, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the new loops of the modified op by dropping the one-trip
  //    count loops.
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find the
  //    - modified affine map to use.
  //    - shape of the operands after the unit-dims are dropped.
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand (needed only when using reshapes
  //      for the rank reduction strategy).
  //    Note that the indexing maps might need changing even if there are no
  //    unit dimensions that are dropped, to handle cases where `0` is used to
  //    access a unit-extent tensor. Consider moving this out of this specific
  //    transformation as a stand-alone transformation. Kept here right now due
  //    to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
    auto replacementInfo =
        dropUnitExtentFromOperandMetadata(rewriter.getContext(), op,
                                          &opOperand, oldDimToNewDimMap,
                                          dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = op.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  //    from the original shape to the shape in the modified operation if
  //    needed, either through use of reshapes or rank-reducing slices as
  //    specified in `options`.
  //    Abort if one of the operands cannot be collapsed.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    FailureOr<Value> collapsedOperand =
        options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
                           reassociations[idx], options);
    if (failed(collapsedOperand)) {
      // Abort if the operand could not be collapsed.
      return failure();
    }
    newOperands.push_back(collapsedOperand.value());
  }

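  // 5. Create the replacement operation using the caller-provided builder,
  //    which also rewrites the op's body for the dropped dimensions.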
  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  //    original type to the new type.
  //    Abort the transformation if the result cannot be expanded back to its
  //    original shape.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
    Value origDest = dpsOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    FailureOr<Value> expanded =
        options.expandFn(rewriter, loc, result, origDest,
                         reassociations[opOperandIndex], options);
    if (failed(expanded)) {
      // Abort if expansion is not successful.
      return failure();
    }
    resultReplacements.push_back(expanded.value());
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  DroppedUnitDimsBuilder build =
      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
         const llvm::SmallDenseSet<unsigned> &droppedDims)
      -> IndexingMapOpInterface {
    auto genericOp = cast<GenericOp>(op);
    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<utils::IteratorType> newIteratorTypes;
    for (auto [index, attr] :
         llvm::enumerate(genericOp.getIteratorTypesArray())) {
      if (!droppedDims.count(index))
        newIteratorTypes.push_back(attr);
    }

    // Create the `linalg.generic` operation with the new operands,
    // indexing maps, iterator types and result types.
    ArrayRef<Value> newInputs =
        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp =
        GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
                          newIndexingMaps, newIteratorTypes);
    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                        replacementOp.getRegion().begin());
    // 5a. Replace `linalg.index` operations that refer to the dropped unit
    //     dimensions.
    IRRewriter rewriter(b);
    replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);

    return replacementOp;
  };

  return dropUnitDims(rewriter, genericOp, build, options);
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
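/// Drops unit dims from `tensor.pad` ops where the collapsed dimensions are
/// not padded. For example, with unit dim 0 allowed, a
///
///   tensor.pad ... low[0, 1] high[0, 1] : tensor<1x4xf32> to tensor<1x6xf32>
///
/// is rewritten into a collapse to tensor<4xf32>, a pad to tensor<6xf32>, and
/// an expansion back to tensor<1x6xf32>.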
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      return getConstantIntValue(f) == 0;
    };

    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<int64_t> newResultShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, outSize, low, high] : zip_equal(
             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newResultShape.push_back(outSize);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    FailureOr<Value> collapsedSource =
        options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
                           newShape, reassociationMap, options);
    if (failed(collapsedSource)) {
      return rewriter.notifyMatchFailure(padOp, "Failed to collapse source");
    }

    auto newResultType = RankedTensorType::get(
        newResultShape, padOp.getResultType().getElementType());
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), /*result=*/newResultType,
        collapsedSource.value(), newLowPad, newHighPad, paddingVal,
        padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
                                     padOp.getResultType().getElementType());
    }

    FailureOr<Value> expandedValue =
        options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                         reassociationMap, options);
    if (failed(expandedValue)) {
      return rewriter.notifyMatchFailure(padOp, "Failed to expand result");
    }
    rewriter.replaceOp(padOp, expandedValue.value());
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), sizes));

    Location loc = sliceOp.getLoc();
    Value newSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
        strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = tensor::CollapseShapeOp::create(
          rewriter, loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
}

void mlir::linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(
    RewritePatternSet &patterns, ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  bool reassociativeReshape =
      options.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape;
  if (reassociativeReshape) {
    patterns.add<RankReducedExtractSliceOp,
                 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
        context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  }
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
  tensor::populateFoldTensorEmptyPatterns(patterns);
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy =
          linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
    }

    // Apply fold unit extent dims patterns with the walk-based driver.
    {
      RewritePatternSet patterns(context);
      populateFoldUnitExtentDimsPatterns(patterns, options);
      walkAndApplyPatterns(op, std::move(patterns));
    }

    // Apply canonicalization patterns with the greedy driver.
    {
      RewritePatternSet patterns(context);
      populateFoldUnitExtentDimsCanonicalizationPatterns(patterns, options);
      (void)applyPatternsGreedily(op, std::move(patterns));
    }
  }
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
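/// For example, a `val` of type `tensor<3x1x4xf32>` with `pos = 1` is
/// collapsed to `tensor<3x4xf32>` using the reassociation [[0], [1, 2]].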
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  ControlDropUnitDims control{};
  FailureOr<Value> collapsed = control.collapseFn(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos), control);
  assert(llvm::succeeded(collapsed) && "Collapsing the value failed");
  return collapsed.value();
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return tensor::ExpandShapeOp::create(
        rewriter, result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
                                      ValueRange{collapsedLhs, collapsedRhs},
                                      ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops.
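/// For example, a `linalg.batch_matmul` with unit batch size,
///   ins(tensor<1x3x4xf32>, tensor<1x4x5xf32>) outs(tensor<1x3x5xf32>),
/// is rewritten into a `linalg.matmul` on the collapsed
///   ins(tensor<3x4xf32>, tensor<4x5xf32>) outs(tensor<3x5xf32>).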
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions.
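/// For example, a `linalg.matmul` with a unit M dimension,
///   ins(tensor<1x4xf32>, tensor<4x5xf32>) outs(tensor<1x5xf32>),
/// is rewritten into a `linalg.vecmat` on
///   ins(tensor<4xf32>, tensor<4x5xf32>) outs(tensor<5xf32>).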
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  // Batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);

  // Non-batch rank-0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}