MLIR  20.0.0git
LowerVectorBitCast.cpp
Go to the documentation of this file.
1 //===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.bitcast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
19 
20 #define DEBUG_TYPE "vector-bitcast-lowering"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 
25 namespace {
26 
27 /// A one-shot unrolling of vector.bitcast to the `targetRank`.
28 ///
29 /// Example:
30 ///
31 /// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
32 ///
33 /// Would be unrolled to:
34 ///
35 /// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
36 /// %0 = vector.extract %a[0, 0, 0] ─┐
37 /// : vector<4xi64> from vector<1x2x3x4xi64> |
38 /// %1 = vector.bitcast %0 | - Repeated 6x for
39 /// : vector<4xi64> to vector<8xi32> | all leading positions
40 /// %2 = vector.insert %1, %result [0, 0, 0] |
41 /// : vector<8xi64> into vector<1x2x3x8xi32> ─┘
42 ///
43 /// Note: If any leading dimension before the `targetRank` is scalable the
44 /// unrolling will stop before the scalable dimension.
45 class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
46 public:
47  UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
48  PatternBenefit benefit = 1)
49  : OpRewritePattern(context, benefit), targetRank(targetRank) {};
50 
51  LogicalResult matchAndRewrite(vector::BitCastOp op,
52  PatternRewriter &rewriter) const override {
53  VectorType resultType = op.getResultVectorType();
54  auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
55  if (!unrollIterator)
56  return failure();
57 
58  auto unrollRank = unrollIterator->getRank();
59  ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
60  ArrayRef<bool> scalableDims =
61  resultType.getScalableDims().drop_front(unrollRank);
62  auto bitcastResType =
63  VectorType::get(shape, resultType.getElementType(), scalableDims);
64 
65  Location loc = op.getLoc();
66  Value result = rewriter.create<arith::ConstantOp>(
67  loc, resultType, rewriter.getZeroAttr(resultType));
68  for (auto position : *unrollIterator) {
69  Value extract =
70  rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
71  Value bitcast =
72  rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
73  result =
74  rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
75  }
76 
77  rewriter.replaceOp(op, result);
78  return success();
79  }
80 
81 private:
82  int64_t targetRank = 1;
83 };
84 
85 } // namespace
86 
88  RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
89  patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
90 }
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358