MLIR  14.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "../PassDetail.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35  : public OpConversionPattern<vector::BitCastOp> {
37 
39  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
40  ConversionPatternRewriter &rewriter) const override {
41  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42  if (!dstType)
43  return failure();
44 
45  if (dstType == adaptor.source().getType())
46  rewriter.replaceOp(bitcastOp, adaptor.source());
47  else
48  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
49  adaptor.source());
50 
51  return success();
52  }
53 };
54 
55 struct VectorBroadcastConvert final
56  : public OpConversionPattern<vector::BroadcastOp> {
58 
60  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
61  ConversionPatternRewriter &rewriter) const override {
62  if (broadcastOp.source().getType().isa<VectorType>() ||
63  !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
64  return failure();
65  SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
66  adaptor.source());
67  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
68  broadcastOp, broadcastOp.getVectorType(), source);
69  return success();
70  }
71 };
72 
73 struct VectorExtractOpConvert final
74  : public OpConversionPattern<vector::ExtractOp> {
76 
78  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
80  // Only support extracting a scalar value now.
81  VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
82  if (resultVectorType && resultVectorType.getNumElements() > 1)
83  return failure();
84 
85  auto dstType = getTypeConverter()->convertType(extractOp.getType());
86  if (!dstType)
87  return failure();
88 
89  if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
90  rewriter.replaceOp(extractOp, adaptor.vector());
91  return success();
92  }
93 
94  int32_t id = getFirstIntValue(extractOp.position());
95  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
96  extractOp, adaptor.vector(), id);
97  return success();
98  }
99 };
100 
101 struct VectorExtractStridedSliceOpConvert final
102  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
104 
106  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
107  ConversionPatternRewriter &rewriter) const override {
108  auto dstType = getTypeConverter()->convertType(extractOp.getType());
109  if (!dstType)
110  return failure();
111 
112 
113  uint64_t offset = getFirstIntValue(extractOp.offsets());
114  uint64_t size = getFirstIntValue(extractOp.sizes());
115  uint64_t stride = getFirstIntValue(extractOp.strides());
116  if (stride != 1)
117  return failure();
118 
119  Value srcVector = adaptor.getOperands().front();
120 
121  // Extract vector<1xT> case.
122  if (dstType.isa<spirv::ScalarType>()) {
123  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
124  srcVector, offset);
125  return success();
126  }
127 
128  SmallVector<int32_t, 2> indices(size);
129  std::iota(indices.begin(), indices.end(), offset);
130 
131  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
132  extractOp, dstType, srcVector, srcVector,
133  rewriter.getI32ArrayAttr(indices));
134 
135  return success();
136  }
137 };
138 
139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
141 
143  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
144  ConversionPatternRewriter &rewriter) const override {
145  if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
146  return failure();
147  rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
148  fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
149  return success();
150  }
151 };
152 
153 struct VectorInsertOpConvert final
154  : public OpConversionPattern<vector::InsertOp> {
156 
158  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
159  ConversionPatternRewriter &rewriter) const override {
160  if (insertOp.getSourceType().isa<VectorType>() ||
161  !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
162  return failure();
163  int32_t id = getFirstIntValue(insertOp.position());
164  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
165  insertOp, adaptor.source(), adaptor.dest(), id);
166  return success();
167  }
168 };
169 
170 struct VectorExtractElementOpConvert final
171  : public OpConversionPattern<vector::ExtractElementOp> {
173 
175  matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const override {
177  if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
178  return failure();
179  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
180  extractElementOp, extractElementOp.getType(), adaptor.vector(),
181  extractElementOp.position());
182  return success();
183  }
184 };
185 
186 struct VectorInsertElementOpConvert final
187  : public OpConversionPattern<vector::InsertElementOp> {
189 
191  matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const override {
193  if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
194  return failure();
195  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
196  insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
197  adaptor.source(), insertElementOp.position());
198  return success();
199  }
200 };
201 
202 struct VectorInsertStridedSliceOpConvert final
203  : public OpConversionPattern<vector::InsertStridedSliceOp> {
205 
207  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override {
209  Value srcVector = adaptor.getOperands().front();
210  Value dstVector = adaptor.getOperands().back();
211 
212  // Insert scalar values not supported yet.
213  if (srcVector.getType().isa<spirv::ScalarType>() ||
214  dstVector.getType().isa<spirv::ScalarType>())
215  return failure();
216 
217  uint64_t stride = getFirstIntValue(insertOp.strides());
218  if (stride != 1)
219  return failure();
220 
221  uint64_t totalSize =
222  dstVector.getType().cast<VectorType>().getNumElements();
223  uint64_t insertSize =
224  srcVector.getType().cast<VectorType>().getNumElements();
225  uint64_t offset = getFirstIntValue(insertOp.offsets());
226 
227  SmallVector<int32_t, 2> indices(totalSize);
228  std::iota(indices.begin(), indices.end(), 0);
229  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
230  totalSize);
231 
232  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233  insertOp, dstVector.getType(), dstVector, srcVector,
234  rewriter.getI32ArrayAttr(indices));
235 
236  return success();
237  }
238 };
239 
240 } // namespace
241 
243  RewritePatternSet &patterns) {
244  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
245  VectorExtractElementOpConvert, VectorExtractOpConvert,
246  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
247  VectorInsertElementOpConvert, VectorInsertOpConvert,
248  VectorInsertStridedSliceOpConvert>(typeConverter,
249  patterns.getContext());
250 }
Include the generated interface declarations.
static uint64_t getFirstIntValue(ArrayAttr attr)
Gets the first integer value from attr, assuming it is an integer array attribute.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:215
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:84
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments, and return a reference to the pattern set so insertions can be chained.
Definition: PatternMatch.h:930
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:637
This class implements a pattern rewriter for use with ConversionPatterns.
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector ops to SPIR-V ops.
bool isa() const
Definition: Types.h:234
MLIRContext * getContext() const
Definition: PatternMatch.h:906
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:250
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:97