MLIR 22.0.0git
LowerVectorBroadcast.cpp
Go to the documentation of this file.
1//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.broadcast' operation.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/IR/Location.h"
24
25#define DEBUG_TYPE "vector-broadcast-lowering"
26
27using namespace mlir;
28using namespace mlir::vector;
29
30namespace {
31
32/// Convert a vector.broadcast with a vector operand to a lower rank
33/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
34/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
35class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
36public:
37 using Base::Base;
38
39 LogicalResult matchAndRewrite(vector::BroadcastOp op,
40 PatternRewriter &rewriter) const override {
41 auto loc = op.getLoc();
42 VectorType dstType = op.getResultVectorType();
43 VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
44 Type eltType = dstType.getElementType();
45
46 // A broadcast from a scalar is considered to be in the lowered form.
47 if (!srcType)
48 return rewriter.notifyMatchFailure(
49 op, "broadcast from scalar already in lowered form");
50
51 // Determine rank of source and destination.
52 int64_t srcRank = srcType.getRank();
53 int64_t dstRank = dstType.getRank();
54
55 // Here we are broadcasting to a rank-1 vector. Ensure that the source is a
56 // scalar.
57 if (srcRank <= 1 && dstRank == 1) {
58 SmallVector<int64_t> fullRankPosition(srcRank, 0);
59 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
60 fullRankPosition);
61 assert(!isa<VectorType>(ext.getType()) && "expected scalar");
62 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
63 return success();
64 }
65
66 // Duplicate this rank.
67 // For example:
68 // %x = broadcast %y : k-D to n-D, k < n
69 // becomes:
70 // %b = broadcast %y : k-D to (n-1)-D
71 // %x = [%b,%b,%b,%b] : n-D
72 // becomes:
73 // %b = [%y,%y] : (n-1)-D
74 // %x = [%b,%b,%b,%b] : n-D
75 if (srcRank < dstRank) {
76 // Duplication.
77 VectorType resType = VectorType::Builder(dstType).dropDim(0);
78 Value bcst =
79 vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
80 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
81 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
82 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
83 rewriter.replaceOp(op, result);
84 return success();
85 }
86
87 // Find non-matching dimension, if any.
88 assert(srcRank == dstRank);
89 int64_t m = -1;
90 for (int64_t r = 0; r < dstRank; r++)
91 if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
92 m = r;
93 break;
94 }
95
96 // All trailing dimensions are the same. Simply pass through.
97 if (m == -1) {
98 rewriter.replaceOp(op, op.getSource());
99 return success();
100 }
101
102 // Any non-matching dimension forces a stretch along this rank.
103 // For example:
104 // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
105 // becomes:
106 // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
107 // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
108 // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
109 // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
110 // %x = [%a,%b,%c,%d]
111 // becomes:
112 // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
113 // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
114 // %a = [%u, %v]
115 // ..
116 // %x = [%a,%b,%c,%d]
117 VectorType resType =
118 VectorType::get(dstType.getShape().drop_front(), eltType,
119 dstType.getScalableDims().drop_front());
120 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
121 if (m == 0) {
122 // Stetch at start.
123 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
124 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
125 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
126 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
127 } else {
128 // Stetch not at start.
129 if (dstType.getScalableDims()[0]) {
130 // TODO: For scalable vectors we should emit an scf.for loop.
131 return failure();
132 }
133 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
134 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
135 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
136 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
137 }
138 }
139 rewriter.replaceOp(op, result);
140 return success();
141 }
142};
143} // namespace
144
147 patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
148}
return success()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...