MLIR 22.0.0git
LowerVectorBitCast.cpp
Go to the documentation of this file.
1//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.bitcast' operation.
11//
12//===----------------------------------------------------------------------===//
13
20
21#define DEBUG_TYPE "vector-bitcast-lowering"
22
23using namespace mlir;
24using namespace mlir::vector;
25
26namespace {
27
28/// A one-shot unrolling of vector.bitcast to the `targetRank`.
29///
30/// Example:
31///
32/// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
33///
34/// Would be unrolled to:
35///
36/// %result = ub.poison : vector<1x2x3x8xi32>
37/// %0 = vector.extract %a[0, 0, 0] ─┐
38/// : vector<4xi64> from vector<1x2x3x4xi64> |
39/// %1 = vector.bitcast %0 | - Repeated 6x for
40/// : vector<4xi64> to vector<8xi32> | all leading positions
41/// %2 = vector.insert %1, %result [0, 0, 0] |
42/// : vector<8xi64> into vector<1x2x3x8xi32> ─┘
43///
44/// Note: If any leading dimension before the `targetRank` is scalable the
45/// unrolling will stop before the scalable dimension.
46class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
47public:
48 UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
49 PatternBenefit benefit = 1)
50 : OpRewritePattern(context, benefit), targetRank(targetRank) {};
51
52 LogicalResult matchAndRewrite(vector::BitCastOp op,
53 PatternRewriter &rewriter) const override {
54 VectorType resultType = op.getResultVectorType();
55 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
56 if (!unrollIterator)
57 return failure();
58
59 auto unrollRank = unrollIterator->getRank();
60 ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
61 ArrayRef<bool> scalableDims =
62 resultType.getScalableDims().drop_front(unrollRank);
63 auto bitcastResType =
64 VectorType::get(shape, resultType.getElementType(), scalableDims);
65
66 Location loc = op.getLoc();
67 Value result = ub::PoisonOp::create(rewriter, loc, resultType);
68 for (auto position : *unrollIterator) {
69 Value extract =
70 vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
71 Value bitcast =
72 vector::BitCastOp::create(rewriter, loc, bitcastResType, extract);
73 result =
74 vector::InsertOp::create(rewriter, loc, bitcast, result, position);
75 }
76
77 rewriter.replaceOp(op, result);
78 return success();
79 }
80
81private:
82 int64_t targetRank = 1;
83};
84
85} // namespace
86
88 RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
89 patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
90}
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...