MLIR  19.0.0git
LowerVectorInterleave.cpp
Go to the documentation of this file.
1 //===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.interleave' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 #define DEBUG_TYPE "vector-interleave-lowering"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 
25 namespace {
26 
27 /// A one-shot unrolling of vector.interleave to the `targetRank`.
28 ///
29 /// Example:
30 ///
31 /// ```mlir
32 /// vector.interleave %a, %b : vector<1x2x3x4xi64>
33 /// ```
34 /// Would be unrolled to:
35 /// ```mlir
36 /// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
37 /// %0 = vector.extract %a[0, 0, 0] ─┐
38 /// : vector<4xi64> from vector<1x2x3x4xi64> |
39 /// %1 = vector.extract %b[0, 0, 0] |
40 /// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
41 /// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
42 /// %3 = vector.insert %2, %result [0, 0, 0] |
43 /// : vector<8xi64> into vector<1x2x3x8xi64> ┘
44 /// ```
45 ///
46 /// Note: If any leading dimension before the `targetRank` is scalable the
47 /// unrolling will stop before the scalable dimension.
48 class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
49 public:
50  UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
51  PatternBenefit benefit = 1)
52  : OpRewritePattern(context, benefit), targetRank(targetRank){};
53 
54  LogicalResult matchAndRewrite(vector::InterleaveOp op,
55  PatternRewriter &rewriter) const override {
56  VectorType resultType = op.getResultVectorType();
57  auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
58  if (!unrollIterator)
59  return failure();
60 
61  auto loc = op.getLoc();
62  Value result = rewriter.create<arith::ConstantOp>(
63  loc, resultType, rewriter.getZeroAttr(resultType));
64  for (auto position : *unrollIterator) {
65  Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
66  Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
67  Value interleave =
68  rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
69  result = rewriter.create<InsertOp>(loc, interleave, result, position);
70  }
71 
72  rewriter.replaceOp(op, result);
73  return success();
74  }
75 
76 private:
77  int64_t targetRank = 1;
78 };
79 
80 } // namespace
81 
83  RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
84  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
85 }
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358