MLIR  22.0.0git
LowerVectorBroadcast.cpp
Go to the documentation of this file.
1 //===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.broadcast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
25 #define DEBUG_TYPE "vector-broadcast-lowering"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
30 namespace {
31 
32 /// Convert a vector.broadcast with a vector operand to a lower rank
33 /// vector.broadcast. vector.broadcast with a scalar operand is expected to be
34 /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
35 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
36 public:
38 
39  LogicalResult matchAndRewrite(vector::BroadcastOp op,
40  PatternRewriter &rewriter) const override {
41  auto loc = op.getLoc();
42  VectorType dstType = op.getResultVectorType();
43  VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
44  Type eltType = dstType.getElementType();
45 
46  // A broadcast from a scalar is considered to be in the lowered form.
47  if (!srcType)
48  return rewriter.notifyMatchFailure(
49  op, "broadcast from scalar already in lowered form");
50 
51  // Determine rank of source and destination.
52  int64_t srcRank = srcType.getRank();
53  int64_t dstRank = dstType.getRank();
54 
55  // Here we are broadcasting to a rank-1 vector. Ensure that the source is a
56  // scalar.
57  if (srcRank <= 1 && dstRank == 1) {
58  SmallVector<int64_t> fullRankPosition(srcRank, 0);
59  Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
60  fullRankPosition);
61  assert(!isa<VectorType>(ext.getType()) && "expected scalar");
62  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
63  return success();
64  }
65 
66  // Duplicate this rank.
67  // For example:
68  // %x = broadcast %y : k-D to n-D, k < n
69  // becomes:
70  // %b = broadcast %y : k-D to (n-1)-D
71  // %x = [%b,%b,%b,%b] : n-D
72  // becomes:
73  // %b = [%y,%y] : (n-1)-D
74  // %x = [%b,%b,%b,%b] : n-D
75  if (srcRank < dstRank) {
76  // Duplication.
77  VectorType resType = VectorType::Builder(dstType).dropDim(0);
78  Value bcst =
79  vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
80  Value result = ub::PoisonOp::create(rewriter, loc, dstType);
81  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
82  result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
83  rewriter.replaceOp(op, result);
84  return success();
85  }
86 
87  // Find non-matching dimension, if any.
88  assert(srcRank == dstRank);
89  int64_t m = -1;
90  for (int64_t r = 0; r < dstRank; r++)
91  if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
92  m = r;
93  break;
94  }
95 
96  // All trailing dimensions are the same. Simply pass through.
97  if (m == -1) {
98  rewriter.replaceOp(op, op.getSource());
99  return success();
100  }
101 
102  // Any non-matching dimension forces a stretch along this rank.
103  // For example:
104  // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
105  // becomes:
106  // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
107  // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
108  // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
109  // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
110  // %x = [%a,%b,%c,%d]
111  // becomes:
112  // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
113  // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
114  // %a = [%u, %v]
115  // ..
116  // %x = [%a,%b,%c,%d]
117  VectorType resType =
118  VectorType::get(dstType.getShape().drop_front(), eltType,
119  dstType.getScalableDims().drop_front());
120  Value result = ub::PoisonOp::create(rewriter, loc, dstType);
121  if (m == 0) {
122  // Stetch at start.
123  Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
124  Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
125  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
126  result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
127  } else {
128  // Stetch not at start.
129  if (dstType.getScalableDims()[0]) {
130  // TODO: For scalable vectors we should emit an scf.for loop.
131  return failure();
132  }
133  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
134  Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
135  Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
136  result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
137  }
138  }
139  rewriter.replaceOp(op, result);
140  return success();
141  }
142 };
143 } // namespace
144 
147  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
148 }
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:311
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319