MLIR  21.0.0git
LowerVectorBroadcast.cpp
Go to the documentation of this file.
1 //===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.broadcast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
25 #define DEBUG_TYPE "vector-broadcast-lowering"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
30 namespace {
31 /// Progressive lowering of BroadcastOp.
32 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
33 public:
35 
36  LogicalResult matchAndRewrite(vector::BroadcastOp op,
37  PatternRewriter &rewriter) const override {
38  auto loc = op.getLoc();
39  VectorType dstType = op.getResultVectorType();
40  VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
41  Type eltType = dstType.getElementType();
42 
43  // Scalar to any vector can use splat.
44  if (!srcType) {
45  rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
46  return success();
47  }
48 
49  // Determine rank of source and destination.
50  int64_t srcRank = srcType.getRank();
51  int64_t dstRank = dstType.getRank();
52 
53  // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
54  if (srcRank <= 1 && dstRank == 1) {
55  Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
56  rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
57  return success();
58  }
59 
60  // Duplicate this rank.
61  // For example:
62  // %x = broadcast %y : k-D to n-D, k < n
63  // becomes:
64  // %b = broadcast %y : k-D to (n-1)-D
65  // %x = [%b,%b,%b,%b] : n-D
66  // becomes:
67  // %b = [%y,%y] : (n-1)-D
68  // %x = [%b,%b,%b,%b] : n-D
69  if (srcRank < dstRank) {
70  // Duplication.
71  VectorType resType = VectorType::Builder(dstType).dropDim(0);
72  Value bcst =
73  rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
74  Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
75  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
76  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
77  rewriter.replaceOp(op, result);
78  return success();
79  }
80 
81  // Find non-matching dimension, if any.
82  assert(srcRank == dstRank);
83  int64_t m = -1;
84  for (int64_t r = 0; r < dstRank; r++)
85  if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
86  m = r;
87  break;
88  }
89 
90  // All trailing dimensions are the same. Simply pass through.
91  if (m == -1) {
92  rewriter.replaceOp(op, op.getSource());
93  return success();
94  }
95 
96  // Any non-matching dimension forces a stretch along this rank.
97  // For example:
98  // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
99  // becomes:
100  // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
101  // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
102  // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
103  // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
104  // %x = [%a,%b,%c,%d]
105  // becomes:
106  // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
107  // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
108  // %a = [%u, %v]
109  // ..
110  // %x = [%a,%b,%c,%d]
111  VectorType resType =
112  VectorType::get(dstType.getShape().drop_front(), eltType,
113  dstType.getScalableDims().drop_front());
114  Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
115  if (m == 0) {
116  // Stetch at start.
117  Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
118  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
119  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
120  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
121  } else {
122  // Stetch not at start.
123  if (dstType.getScalableDims()[0]) {
124  // TODO: For scalable vectors we should emit an scf.for loop.
125  return failure();
126  }
127  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
128  Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
129  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
130  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
131  }
132  }
133  rewriter.replaceOp(op, result);
134  return success();
135  }
136 };
137 } // namespace
138 
141  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
142 }
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:295
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319