MLIR  16.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/BuiltinTypes.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include <numeric>
26 
27 using namespace mlir;
28 
29 /// Gets the first integer value from `attr`, assuming it is an integer array
30 /// attribute.
31 static uint64_t getFirstIntValue(ArrayAttr attr) {
32  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
33 }
34 
35 namespace {
36 
37 struct VectorBitcastConvert final
38  : public OpConversionPattern<vector::BitCastOp> {
40 
42  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
43  ConversionPatternRewriter &rewriter) const override {
44  Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
45  if (!dstType)
46  return failure();
47 
48  if (dstType == adaptor.getSource().getType())
49  rewriter.replaceOp(bitcastOp, adaptor.getSource());
50  else
51  rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
52  adaptor.getSource());
53 
54  return success();
55  }
56 };
57 
58 struct VectorBroadcastConvert final
59  : public OpConversionPattern<vector::BroadcastOp> {
61 
63  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const override {
65  Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
66  if (!resultType)
67  return failure();
68 
69  if (resultType.isa<spirv::ScalarType>()) {
70  rewriter.replaceOp(castOp, adaptor.getSource());
71  return success();
72  }
73 
74  SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
75  adaptor.getSource());
76  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
77  castOp, castOp.getVectorType(), source);
78  return success();
79  }
80 };
81 
82 struct VectorExtractOpConvert final
83  : public OpConversionPattern<vector::ExtractOp> {
85 
87  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override {
89  // Only support extracting a scalar value now.
90  VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
91  if (resultVectorType && resultVectorType.getNumElements() > 1)
92  return failure();
93 
94  Type dstType = getTypeConverter()->convertType(extractOp.getType());
95  if (!dstType)
96  return failure();
97 
98  if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
99  rewriter.replaceOp(extractOp, adaptor.getVector());
100  return success();
101  }
102 
103  int32_t id = getFirstIntValue(extractOp.getPosition());
104  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
105  extractOp, adaptor.getVector(), id);
106  return success();
107  }
108 };
109 
110 struct VectorExtractStridedSliceOpConvert final
111  : public OpConversionPattern<vector::ExtractStridedSliceOp> {
113 
115  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
116  ConversionPatternRewriter &rewriter) const override {
117  Type dstType = getTypeConverter()->convertType(extractOp.getType());
118  if (!dstType)
119  return failure();
120 
121  uint64_t offset = getFirstIntValue(extractOp.getOffsets());
122  uint64_t size = getFirstIntValue(extractOp.getSizes());
123  uint64_t stride = getFirstIntValue(extractOp.getStrides());
124  if (stride != 1)
125  return failure();
126 
127  Value srcVector = adaptor.getOperands().front();
128 
129  // Extract vector<1xT> case.
130  if (dstType.isa<spirv::ScalarType>()) {
131  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
132  srcVector, offset);
133  return success();
134  }
135 
136  SmallVector<int32_t, 2> indices(size);
137  std::iota(indices.begin(), indices.end(), offset);
138 
139  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
140  extractOp, dstType, srcVector, srcVector,
141  rewriter.getI32ArrayAttr(indices));
142 
143  return success();
144  }
145 };
146 
147 template <class SPIRVFMAOp>
148 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
150 
152  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
153  ConversionPatternRewriter &rewriter) const override {
154  Type dstType = getTypeConverter()->convertType(fmaOp.getType());
155  if (!dstType)
156  return failure();
157  rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
158  adaptor.getRhs(), adaptor.getAcc());
159  return success();
160  }
161 };
162 
163 struct VectorInsertOpConvert final
164  : public OpConversionPattern<vector::InsertOp> {
166 
168  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
169  ConversionPatternRewriter &rewriter) const override {
170  // Special case for inserting scalar values into size-1 vectors.
171  if (insertOp.getSourceType().isIntOrFloat() &&
172  insertOp.getDestVectorType().getNumElements() == 1) {
173  rewriter.replaceOp(insertOp, adaptor.getSource());
174  return success();
175  }
176 
177  if (insertOp.getSourceType().isa<VectorType>() ||
178  !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
179  return failure();
180  int32_t id = getFirstIntValue(insertOp.getPosition());
181  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
182  insertOp, adaptor.getSource(), adaptor.getDest(), id);
183  return success();
184  }
185 };
186 
187 struct VectorExtractElementOpConvert final
188  : public OpConversionPattern<vector::ExtractElementOp> {
190 
192  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
193  ConversionPatternRewriter &rewriter) const override {
194  Type vectorType =
195  getTypeConverter()->convertType(adaptor.getVector().getType());
196  if (!vectorType)
197  return failure();
198 
199  if (vectorType.isa<spirv::ScalarType>()) {
200  rewriter.replaceOp(extractOp, adaptor.getVector());
201  return success();
202  }
203 
204  rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
205  extractOp, extractOp.getType(), adaptor.getVector(),
206  extractOp.getPosition());
207  return success();
208  }
209 };
210 
211 struct VectorInsertElementOpConvert final
212  : public OpConversionPattern<vector::InsertElementOp> {
214 
216  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override {
218  Type vectorType = getTypeConverter()->convertType(insertOp.getType());
219  if (!vectorType)
220  return failure();
221 
222  if (vectorType.isa<spirv::ScalarType>()) {
223  rewriter.replaceOp(insertOp, adaptor.getSource());
224  return success();
225  }
226 
227  rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
228  insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
229  insertOp.getPosition());
230  return success();
231  }
232 };
233 
234 struct VectorInsertStridedSliceOpConvert final
235  : public OpConversionPattern<vector::InsertStridedSliceOp> {
237 
239  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
240  ConversionPatternRewriter &rewriter) const override {
241  Value srcVector = adaptor.getOperands().front();
242  Value dstVector = adaptor.getOperands().back();
243 
244  uint64_t stride = getFirstIntValue(insertOp.getStrides());
245  if (stride != 1)
246  return failure();
247  uint64_t offset = getFirstIntValue(insertOp.getOffsets());
248 
249  if (srcVector.getType().isa<spirv::ScalarType>()) {
250  assert(!dstVector.getType().isa<spirv::ScalarType>());
251  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
252  insertOp, dstVector.getType(), srcVector, dstVector,
253  rewriter.getI32ArrayAttr(offset));
254  return success();
255  }
256 
257  uint64_t totalSize =
258  dstVector.getType().cast<VectorType>().getNumElements();
259  uint64_t insertSize =
260  srcVector.getType().cast<VectorType>().getNumElements();
261 
262  SmallVector<int32_t, 2> indices(totalSize);
263  std::iota(indices.begin(), indices.end(), 0);
264  std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
265  totalSize);
266 
267  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
268  insertOp, dstVector.getType(), dstVector, srcVector,
269  rewriter.getI32ArrayAttr(indices));
270 
271  return success();
272  }
273 };
274 
275 template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
276  class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
277 struct VectorReductionPattern final
278  : public OpConversionPattern<vector::ReductionOp> {
280 
282  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
283  ConversionPatternRewriter &rewriter) const override {
284  Type resultType = typeConverter->convertType(reduceOp.getType());
285  if (!resultType)
286  return failure();
287 
288  auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
289  if (!srcVectorType || srcVectorType.getRank() != 1)
290  return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
291 
292  // Extract all elements.
293  int numElements = srcVectorType.getDimSize(0);
294  SmallVector<Value, 4> values;
295  values.reserve(numElements + (adaptor.getAcc() != nullptr));
296  Location loc = reduceOp.getLoc();
297  for (int i = 0; i < numElements; ++i) {
298  values.push_back(rewriter.create<spirv::CompositeExtractOp>(
299  loc, srcVectorType.getElementType(), adaptor.getVector(),
300  rewriter.getI32ArrayAttr({i})));
301  }
302  if (Value acc = adaptor.getAcc())
303  values.push_back(acc);
304 
305  // Reduce them.
306  Value result = values.front();
307  for (Value next : llvm::makeArrayRef(values).drop_front()) {
308  switch (reduceOp.getKind()) {
309 
310 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
311  case vector::CombiningKind::kind: \
312  if (resultType.isa<IntegerType>()) { \
313  result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
314  } else { \
315  assert(resultType.isa<FloatType>()); \
316  result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
317  } \
318  break
319 
320 #define INT_OR_FLOAT_CASE(kind, fop) \
321  case vector::CombiningKind::kind: \
322  result = rewriter.create<fop>(loc, resultType, result, next); \
323  break
324 
325  INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
326  INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
327 
328  INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
329  INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
330  INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
331  INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
332  INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
333  INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
334 
335  case vector::CombiningKind::AND:
336  case vector::CombiningKind::OR:
337  case vector::CombiningKind::XOR:
338  return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
339  }
340  }
341 
342  rewriter.replaceOp(reduceOp, result);
343  return success();
344  }
345 };
346 
347 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
348 public:
350 
352  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
353  ConversionPatternRewriter &rewriter) const override {
354  Type dstType = getTypeConverter()->convertType(op.getType());
355  if (!dstType)
356  return failure();
357  if (dstType.isa<spirv::ScalarType>()) {
358  rewriter.replaceOp(op, adaptor.getInput());
359  } else {
360  auto dstVecType = dstType.cast<VectorType>();
361  SmallVector<Value, 4> source(dstVecType.getNumElements(),
362  adaptor.getInput());
363  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
364  source);
365  }
366  return success();
367  }
368 };
369 
370 struct VectorShuffleOpConvert final
371  : public OpConversionPattern<vector::ShuffleOp> {
373 
375  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
376  ConversionPatternRewriter &rewriter) const override {
377  auto oldResultType = shuffleOp.getVectorType();
378  if (!spirv::CompositeType::isValid(oldResultType))
379  return failure();
380  Type newResultType = getTypeConverter()->convertType(oldResultType);
381 
382  auto oldSourceType = shuffleOp.getV1VectorType();
383  if (oldSourceType.getNumElements() > 1) {
384  SmallVector<int32_t, 4> components = llvm::to_vector<4>(
385  llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
386  return attr.cast<IntegerAttr>().getValue().getZExtValue();
387  }));
388  rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
389  shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
390  rewriter.getI32ArrayAttr(components));
391  return success();
392  }
393 
394  SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
395  SmallVector<Value, 4> newOperands;
396  newOperands.reserve(oldResultType.getNumElements());
397  for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
398  newOperands.push_back(oldOperands[i.getZExtValue()]);
399  }
400  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
401  shuffleOp, newResultType, newOperands);
402 
403  return success();
404  }
405 };
406 
407 } // namespace
// Min/max op types of the OpenCL (CL) and GLSL (GL) extended instruction
// sets, listed in the order of VectorReductionPattern's template parameters:
// FMax, FMin, UMax, UMin, SMax, SMin.
#define CL_MAX_MIN_OPS \
  spirv::CLFMaxOp, spirv::CLFMinOp, \
  spirv::CLUMaxOp, spirv::CLUMinOp, \
  spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_MAX_MIN_OPS \
  spirv::GLFMaxOp, spirv::GLFMinOp, \
  spirv::GLUMaxOp, spirv::GLUMinOp, \
  spirv::GLSMaxOp, spirv::GLSMinOp
415 
417  RewritePatternSet &patterns) {
418  patterns.add<
419  VectorBitcastConvert, VectorBroadcastConvert,
420  VectorExtractElementOpConvert, VectorExtractOpConvert,
421  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
422  VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
423  VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
424  VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
425  VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
426  patterns.getContext());
427 }
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
static uint64_t getFirstIntValue(ArrayAttr attr)
Gets the first integer value from attr, assuming it is an integer array attribute.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
#define INT_AND_FLOAT_CASE(kind, iop, fop)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:253
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
#define INT_OR_FLOAT_CASE(kind, fop)
Type getType() const
Return the type of this value.
Definition: Value.h:118
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1174
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
This class implements a pattern rewriter for use with ConversionPatterns.
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops...
bool isa() const
Definition: Types.h:259
MLIRContext * getContext() const
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:279
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:99