MLIR 23.0.0git
VectorToSPIRV.cpp
Go to the documentation of this file.
1//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Vector dialect to SPIRV dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
21#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Location.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
33#include <cassert>
34#include <cstdint>
35#include <numeric>
36
37using namespace mlir;
38
39/// Returns the integer value from the first valid input element, assuming Value
40/// inputs are defined by a constant index ops and Attribute inputs are integer
41/// attributes.
42static uint64_t getFirstIntValue(ArrayAttr attr) {
43 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
44}
45
46/// Returns the number of bits for the given scalar/vector type.
47static int getNumBits(Type type) {
48 // TODO: This does not take into account any memory layout or widening
49 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
50 // though in practice it will likely be stored as in a 4xi64 vector register.
51 if (auto vectorType = dyn_cast<VectorType>(type))
52 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
53 return type.getIntOrFloatBitWidth();
54}
55
56namespace {
57
58struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
59 using Base::Base;
60
61 LogicalResult
62 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter) const override {
64 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
65 if (!dstType)
66 return failure();
67
68 // If dstType is same as the source type or the vector size is 1, it can be
69 // directly replaced by the source.
70 if (dstType == adaptor.getSource().getType() ||
71 shapeCastOp.getResultVectorType().getNumElements() == 1) {
72 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
73 return success();
74 }
75
76 // Lowering for size-n vectors when n > 1 hasn't been implemented.
77 return failure();
78 }
79};
80
81struct VectorBitcastConvert final
82 : public OpConversionPattern<vector::BitCastOp> {
83 using Base::Base;
84
85 LogicalResult
86 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter) const override {
88 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
89 if (!dstType)
90 return failure();
91
92 if (dstType == adaptor.getSource().getType()) {
93 rewriter.replaceOp(bitcastOp, adaptor.getSource());
94 return success();
95 }
96
97 // Check that the source and destination type have the same bitwidth.
98 // Depending on the target environment, we may need to emulate certain
99 // types, which can cause issue with bitcast.
100 Type srcType = adaptor.getSource().getType();
101 if (getNumBits(dstType) != getNumBits(srcType)) {
102 return rewriter.notifyMatchFailure(
103 bitcastOp,
104 llvm::formatv("different source ({0}) and target ({1}) bitwidth",
105 srcType, dstType));
106 }
107
108 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
109 adaptor.getSource());
110 return success();
111 }
112};
113
114struct VectorBroadcastConvert final
115 : public OpConversionPattern<vector::BroadcastOp> {
116 using Base::Base;
117
118 LogicalResult
119 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter) const override {
121 Type resultType =
122 getTypeConverter()->convertType(castOp.getResultVectorType());
123 if (!resultType)
124 return failure();
125
126 if (isa<spirv::ScalarType>(resultType)) {
127 rewriter.replaceOp(castOp, adaptor.getSource());
128 return success();
129 }
130
131 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
132 adaptor.getSource());
133 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
134 source);
135 return success();
136 }
137};
138
139// SPIR-V does not have a concept of a poison index for certain instructions,
140// which creates a UB hazard when lowering from otherwise equivalent Vector
141// dialect instructions, because this index will be considered out-of-bounds.
142// To avoid this, this function implements a dynamic sanitization that returns
143// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
144// (presumably more efficient), and otherwise index 0 (always in-bounds).
145static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
146 Location loc, Value dynamicIndex,
147 int64_t kPoisonIndex, unsigned vectorSize) {
148 if (llvm::isPowerOf2_32(vectorSize)) {
149 Value inBoundsMask = spirv::ConstantOp::create(
150 rewriter, loc, dynamicIndex.getType(),
151 rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
152 return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex,
153 inBoundsMask);
154 }
155 Value poisonIndex = spirv::ConstantOp::create(
156 rewriter, loc, dynamicIndex.getType(),
157 rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
158 Value cmpResult =
159 spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex);
160 return spirv::SelectOp::create(
161 rewriter, loc, cmpResult,
162 spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
163 dynamicIndex);
164}
165
166struct VectorExtractOpConvert final
167 : public OpConversionPattern<vector::ExtractOp> {
168 using Base::Base;
169
170 LogicalResult
171 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter) const override {
173 Type dstType = getTypeConverter()->convertType(extractOp.getType());
174 if (!dstType)
175 return failure();
176
177 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
178 rewriter.replaceOp(extractOp, adaptor.getSource());
179 return success();
180 }
181
182 if (std::optional<int64_t> id =
183 getConstantIntValue(extractOp.getMixedPosition()[0])) {
184 if (id == vector::ExtractOp::kPoisonIndex)
185 return rewriter.notifyMatchFailure(
186 extractOp,
187 "Static use of poison index handled elsewhere (folded to poison)");
188 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
189 extractOp, dstType, adaptor.getSource(),
190 rewriter.getI32ArrayAttr(id.value()));
191 } else {
192 Value sanitizedIndex = sanitizeDynamicIndex(
193 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
194 vector::ExtractOp::kPoisonIndex,
195 extractOp.getSourceVectorType().getNumElements());
196 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
197 extractOp, dstType, adaptor.getSource(), sanitizedIndex);
198 }
199 return success();
200 }
201};
202
203struct VectorExtractStridedSliceOpConvert final
204 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
205 using Base::Base;
206
207 LogicalResult
208 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const override {
210 Type dstType = getTypeConverter()->convertType(extractOp.getType());
211 if (!dstType)
212 return failure();
213
214 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
215 uint64_t size = getFirstIntValue(extractOp.getSizes());
216 uint64_t stride = getFirstIntValue(extractOp.getStrides());
217 if (stride != 1)
218 return failure();
219
220 Value srcVector = adaptor.getOperands().front();
221
222 // Extract vector<1xT> case.
223 if (isa<spirv::ScalarType>(dstType)) {
224 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
225 srcVector, offset);
226 return success();
227 }
228
229 SmallVector<int32_t, 2> indices(size);
230 std::iota(indices.begin(), indices.end(), offset);
231
232 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233 extractOp, dstType, srcVector, srcVector,
234 rewriter.getI32ArrayAttr(indices));
235
236 return success();
237 }
238};
239
240template <class SPIRVFMAOp>
241struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
242 using Base::Base;
243
244 LogicalResult
245 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
248 if (!dstType)
249 return failure();
250 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
251 adaptor.getRhs(), adaptor.getAcc());
252 return success();
253 }
254};
255
256struct VectorFromElementsOpConvert final
257 : public OpConversionPattern<vector::FromElementsOp> {
258 using Base::Base;
259
260 LogicalResult
261 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 Type resultType = getTypeConverter()->convertType(op.getType());
264 if (!resultType)
265 return failure();
266 ValueRange elements = adaptor.getElements();
267 if (isa<spirv::ScalarType>(resultType)) {
268 // In the case with a single scalar operand / single-element result,
269 // pass through the scalar.
270 rewriter.replaceOp(op, elements[0]);
271 return success();
272 }
273 // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
274 // vector.from_elements cases should not need to be handled, only 1d.
275 assert(cast<VectorType>(resultType).getRank() == 1);
276 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
277 elements);
278 return success();
279 }
280};
281
282struct VectorInsertOpConvert final
283 : public OpConversionPattern<vector::InsertOp> {
284 using Base::Base;
285
286 LogicalResult
287 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override {
289 if (isa<VectorType>(insertOp.getValueToStoreType()))
290 return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
291 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
292 return rewriter.notifyMatchFailure(insertOp,
293 "unsupported dest vector type");
294
295 // Special case for inserting scalar values into size-1 vectors.
296 if (insertOp.getValueToStoreType().isIntOrFloat() &&
297 insertOp.getDestVectorType().getNumElements() == 1) {
298 rewriter.replaceOp(insertOp, adaptor.getValueToStore());
299 return success();
300 }
301
302 if (std::optional<int64_t> id =
303 getConstantIntValue(insertOp.getMixedPosition()[0])) {
304 if (id == vector::InsertOp::kPoisonIndex)
305 return rewriter.notifyMatchFailure(
306 insertOp,
307 "Static use of poison index handled elsewhere (folded to poison)");
308 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
309 insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
310 } else {
311 Value sanitizedIndex = sanitizeDynamicIndex(
312 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
313 vector::InsertOp::kPoisonIndex,
314 insertOp.getDestVectorType().getNumElements());
315 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
316 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
317 sanitizedIndex);
318 }
319 return success();
320 }
321};
322
323struct VectorInsertStridedSliceOpConvert final
324 : public OpConversionPattern<vector::InsertStridedSliceOp> {
325 using Base::Base;
326
327 LogicalResult
328 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 Value srcVector = adaptor.getOperands().front();
331 Value dstVector = adaptor.getOperands().back();
332
333 uint64_t stride = getFirstIntValue(insertOp.getStrides());
334 if (stride != 1)
335 return failure();
336 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
337
338 if (isa<spirv::ScalarType>(srcVector.getType())) {
339 assert(!isa<spirv::ScalarType>(dstVector.getType()));
340 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
341 insertOp, dstVector.getType(), srcVector, dstVector,
342 rewriter.getI32ArrayAttr(offset));
343 return success();
344 }
345
346 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
347 uint64_t insertSize =
348 cast<VectorType>(srcVector.getType()).getNumElements();
349
350 SmallVector<int32_t, 2> indices(totalSize);
351 std::iota(indices.begin(), indices.end(), 0);
352 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
353 totalSize);
354
355 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
356 insertOp, dstVector.getType(), dstVector, srcVector,
357 rewriter.getI32ArrayAttr(indices));
358
359 return success();
360 }
361};
362
363static SmallVector<Value> extractAllElements(
364 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
365 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
366 int numElements = static_cast<int>(srcVectorType.getDimSize(0));
367 SmallVector<Value> values;
368 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
369 Location loc = reduceOp.getLoc();
370
371 for (int i = 0; i < numElements; ++i) {
372 values.push_back(spirv::CompositeExtractOp::create(
373 rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
374 rewriter.getI32ArrayAttr({i})));
375 }
376 if (Value acc = adaptor.getAcc())
377 values.push_back(acc);
378
379 return values;
380}
381
382struct ReductionRewriteInfo {
383 Type resultType;
384 SmallVector<Value> extractedElements;
385};
386
387FailureOr<ReductionRewriteInfo> static getReductionInfo(
388 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
389 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
390 Type resultType = typeConverter.convertType(op.getType());
391 if (!resultType)
392 return failure();
393
394 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
395 if (!srcVectorType || srcVectorType.getRank() != 1)
396 return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
397
398 SmallVector<Value> extractedElements =
399 extractAllElements(op, adaptor, srcVectorType, rewriter);
400
401 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
402}
403
404template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
405 typename SPIRVSMinOp>
406struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
407 using Base::Base;
408
409 LogicalResult
410 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter) const override {
412 auto reductionInfo =
413 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
414 if (failed(reductionInfo))
415 return failure();
416
417 auto [resultType, extractedElements] = *reductionInfo;
418 Location loc = reduceOp->getLoc();
419
420 // Handle boolean reductions with spirv.Any / spirv.All.
421 if (resultType.isInteger(1)) {
422 vector::CombiningKind kind = reduceOp.getKind();
423
424 if (kind == vector::CombiningKind::OR) {
425 Value result = spirv::AnyOp::create(rewriter, loc, resultType,
426 adaptor.getVector());
427 if (Value acc = adaptor.getAcc())
428 result = spirv::LogicalOrOp::create(rewriter, loc, resultType, result,
429 acc);
430 rewriter.replaceOp(reduceOp, result);
431 return success();
432 }
433
434 if (kind == vector::CombiningKind::AND) {
435 Value result = spirv::AllOp::create(rewriter, loc, resultType,
436 adaptor.getVector());
437 if (Value acc = adaptor.getAcc())
438 result = spirv::LogicalAndOp::create(rewriter, loc, resultType,
439 result, acc);
440 rewriter.replaceOp(reduceOp, result);
441 return success();
442 }
443 }
444
445 Value result = extractedElements.front();
446 for (Value next : llvm::drop_begin(extractedElements)) {
447 switch (reduceOp.getKind()) {
448
449#define INT_AND_FLOAT_CASE(kind, iop, fop) \
450 case vector::CombiningKind::kind: \
451 if (isa<IntegerType>(resultType)) { \
452 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
453 } else { \
454 assert(isa<FloatType>(resultType)); \
455 result = spirv::fop::create(rewriter, loc, resultType, result, next); \
456 } \
457 break
458
459#define INT_OR_FLOAT_CASE(kind, fop) \
460 case vector::CombiningKind::kind: \
461 result = fop::create(rewriter, loc, resultType, result, next); \
462 break
463
464#define INT_CASE(kind, iop) \
465 case vector::CombiningKind::kind: \
466 assert(isa<IntegerType>(resultType)); \
467 result = spirv::iop::create(rewriter, loc, resultType, result, next); \
468 break
469
470 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
471 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
472 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
473 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
474 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
475 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
476 INT_CASE(AND, BitwiseAndOp);
477 INT_CASE(OR, BitwiseOrOp);
478 INT_CASE(XOR, BitwiseXorOp);
479
480 default:
481 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
482 }
483#undef INT_AND_FLOAT_CASE
484#undef INT_OR_FLOAT_CASE
485#undef INT_CASE
486 }
487
488 rewriter.replaceOp(reduceOp, result);
489 return success();
490 }
491};
492
493template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
494struct VectorReductionFloatMinMax final
495 : OpConversionPattern<vector::ReductionOp> {
496 using Base::Base;
497
498 LogicalResult
499 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
500 ConversionPatternRewriter &rewriter) const override {
501 auto reductionInfo =
502 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
503 if (failed(reductionInfo))
504 return failure();
505
506 auto [resultType, extractedElements] = *reductionInfo;
507 Location loc = reduceOp->getLoc();
508 Value result = extractedElements.front();
509 for (Value next : llvm::drop_begin(extractedElements)) {
510 switch (reduceOp.getKind()) {
511
512#define INT_OR_FLOAT_CASE(kind, fop) \
513 case vector::CombiningKind::kind: \
514 result = fop::create(rewriter, loc, resultType, result, next); \
515 break
516
517 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
518 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
519 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
520 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
521
522 default:
523 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
524 }
525#undef INT_OR_FLOAT_CASE
526 }
527
528 rewriter.replaceOp(reduceOp, result);
529 return success();
530 }
531};
532
533class VectorScalarBroadcastPattern final
534 : public OpConversionPattern<vector::BroadcastOp> {
535public:
536 using Base::Base;
537
538 LogicalResult
539 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 if (isa<VectorType>(op.getSourceType())) {
542 return rewriter.notifyMatchFailure(
543 op, "only conversion of 'broadcast from scalar' is supported");
544 }
545 Type dstType = getTypeConverter()->convertType(op.getType());
546 if (!dstType)
547 return failure();
548 if (isa<spirv::ScalarType>(dstType)) {
549 rewriter.replaceOp(op, adaptor.getSource());
550 } else {
551 auto dstVecType = cast<VectorType>(dstType);
552 SmallVector<Value, 4> source(dstVecType.getNumElements(),
553 adaptor.getSource());
554 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
555 source);
556 }
557 return success();
558 }
559};
560
561struct VectorShuffleOpConvert final
562 : public OpConversionPattern<vector::ShuffleOp> {
563 using Base::Base;
564
565 LogicalResult
566 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
567 ConversionPatternRewriter &rewriter) const override {
568 VectorType oldResultType = shuffleOp.getResultVectorType();
569 Type newResultType = getTypeConverter()->convertType(oldResultType);
570 if (!newResultType)
571 return rewriter.notifyMatchFailure(shuffleOp,
572 "unsupported result vector type");
573
574 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
575
576 VectorType oldV1Type = shuffleOp.getV1VectorType();
577 VectorType oldV2Type = shuffleOp.getV2VectorType();
578
579 // When both operands and the result are SPIR-V vectors, emit a SPIR-V
580 // shuffle.
581 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
582 oldResultType.getNumElements() > 1) {
583 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
584 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
585 rewriter.getI32ArrayAttr(mask));
586 return success();
587 }
588
589 // When at least one of the operands or the result becomes a scalar after
590 // type conversion for SPIR-V, extract all the required elements and
591 // construct the result vector.
592 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
593 Value scalarOrVec, int32_t idx) -> Value {
594 if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
595 return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
596 idx);
597
598 assert(idx == 0 && "Invalid scalar element index");
599 return scalarOrVec;
600 };
601
602 int32_t numV1Elems = oldV1Type.getNumElements();
603 SmallVector<Value> newOperands(mask.size());
604 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
605 Value vec = adaptor.getV1();
606 int32_t elementIdx = shuffleIdx;
607 if (elementIdx >= numV1Elems) {
608 vec = adaptor.getV2();
609 elementIdx -= numV1Elems;
610 }
611
612 newOperand = getElementAtIdx(vec, elementIdx);
613 }
614
615 // Handle the scalar result corner case.
616 if (newOperands.size() == 1) {
617 rewriter.replaceOp(shuffleOp, newOperands.front());
618 return success();
619 }
620
621 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
622 shuffleOp, newResultType, newOperands);
623 return success();
624 }
625};
626
627struct VectorInterleaveOpConvert final
628 : public OpConversionPattern<vector::InterleaveOp> {
629 using Base::Base;
630
631 LogicalResult
632 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter) const override {
634 // Check the result vector type.
635 VectorType oldResultType = interleaveOp.getResultVectorType();
636 Type newResultType = getTypeConverter()->convertType(oldResultType);
637 if (!newResultType)
638 return rewriter.notifyMatchFailure(interleaveOp,
639 "unsupported result vector type");
640
641 // Interleave the indices.
642 VectorType sourceType = interleaveOp.getSourceVectorType();
643 int n = sourceType.getNumElements();
644
645 // Input vectors of size 1 are converted to scalars by the type converter.
646 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
647 // use `spirv::CompositeConstructOp`.
648 if (n == 1) {
649 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
650 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
651 interleaveOp, newResultType, newOperands);
652 return success();
653 }
654
655 auto seq = llvm::seq<int64_t>(2 * n);
656 auto indices = llvm::map_to_vector(
657 seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
658
659 // Emit a SPIR-V shuffle.
660 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
661 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
662 rewriter.getI32ArrayAttr(indices));
663
664 return success();
665 }
666};
667
668struct VectorDeinterleaveOpConvert final
669 : public OpConversionPattern<vector::DeinterleaveOp> {
670 using Base::Base;
671
672 LogicalResult
673 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter) const override {
675
676 // Check the result vector type.
677 VectorType oldResultType = deinterleaveOp.getResultVectorType();
678 Type newResultType = getTypeConverter()->convertType(oldResultType);
679 if (!newResultType)
680 return rewriter.notifyMatchFailure(deinterleaveOp,
681 "unsupported result vector type");
682
683 Location loc = deinterleaveOp->getLoc();
684
685 // Deinterleave the indices.
686 Value sourceVector = adaptor.getSource();
687 VectorType sourceType = deinterleaveOp.getSourceVectorType();
688 int n = sourceType.getNumElements();
689
690 // Output vectors of size 1 are converted to scalars by the type converter.
691 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
692 // use `spirv::CompositeExtractOp`.
693 if (n == 2) {
694 auto elem0 = spirv::CompositeExtractOp::create(
695 rewriter, loc, newResultType, sourceVector,
696 rewriter.getI32ArrayAttr({0}));
697
698 auto elem1 = spirv::CompositeExtractOp::create(
699 rewriter, loc, newResultType, sourceVector,
700 rewriter.getI32ArrayAttr({1}));
701
702 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
703 return success();
704 }
705
706 // Indices for `shuffleEven` (result 0).
707 auto seqEven = llvm::seq<int64_t>(n / 2);
708 auto indicesEven =
709 llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
710
711 // Indices for `shuffleOdd` (result 1).
712 auto seqOdd = llvm::seq<int64_t>(n / 2);
713 auto indicesOdd =
714 llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
715
716 // Create two SPIR-V shuffles.
717 auto shuffleEven = spirv::VectorShuffleOp::create(
718 rewriter, loc, newResultType, sourceVector, sourceVector,
719 rewriter.getI32ArrayAttr(indicesEven));
720
721 auto shuffleOdd = spirv::VectorShuffleOp::create(
722 rewriter, loc, newResultType, sourceVector, sourceVector,
723 rewriter.getI32ArrayAttr(indicesOdd));
724
725 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
726 return success();
727 }
728};
729
730struct VectorLoadOpConverter final
731 : public OpConversionPattern<vector::LoadOp> {
732 using Base::Base;
733
734 LogicalResult
735 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter) const override {
737 auto memrefType = loadOp.getMemRefType();
738 auto attr =
739 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
740 if (!attr)
741 return rewriter.notifyMatchFailure(
742 loadOp, "expected spirv.storage_class memory space");
743
744 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
745 auto loc = loadOp.getLoc();
746 Value accessChain =
747 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
748 adaptor.getIndices(), loc, rewriter);
749 if (!accessChain)
750 return rewriter.notifyMatchFailure(
751 loadOp, "failed to get memref element pointer");
752
753 spirv::StorageClass storageClass = attr.getValue();
754 auto vectorType = loadOp.getVectorType();
755 // Use the converted vector type instead of original (single element vector
756 // would get converted to scalar).
757 auto spirvVectorType = typeConverter.convertType(vectorType);
758 if (!spirvVectorType)
759 return rewriter.notifyMatchFailure(loadOp, "unsupported vector type");
760
761 auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
762
763 std::optional<uint64_t> alignment = loadOp.getAlignment();
764 if (alignment > std::numeric_limits<uint32_t>::max()) {
765 return rewriter.notifyMatchFailure(loadOp,
766 "invalid alignment requirement");
767 }
768
769 auto memoryAccess = spirv::MemoryAccess::None;
770 spirv::MemoryAccessAttr memoryAccessAttr;
771 IntegerAttr alignmentAttr;
772 if (alignment.has_value()) {
773 memoryAccess |= spirv::MemoryAccess::Aligned;
774 memoryAccessAttr =
775 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
776 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
777 }
778
779 // For single element vectors, we don't need to bitcast the access chain to
780 // the original vector type. Both is going to be the same, a pointer
781 // to a scalar.
782 Value castedAccessChain =
783 (vectorType.getNumElements() == 1)
784 ? accessChain
785 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
786 accessChain);
787
788 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
789 castedAccessChain,
790 memoryAccessAttr, alignmentAttr);
791
792 return success();
793 }
794};
795
796struct VectorStoreOpConverter final
797 : public OpConversionPattern<vector::StoreOp> {
798 using Base::Base;
799
800 LogicalResult
801 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
802 ConversionPatternRewriter &rewriter) const override {
803 auto memrefType = storeOp.getMemRefType();
804 auto attr =
805 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
806 if (!attr)
807 return rewriter.notifyMatchFailure(
808 storeOp, "expected spirv.storage_class memory space");
809
810 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
811 auto loc = storeOp.getLoc();
812 Value accessChain =
813 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
814 adaptor.getIndices(), loc, rewriter);
815 if (!accessChain)
816 return rewriter.notifyMatchFailure(
817 storeOp, "failed to get memref element pointer");
818
819 std::optional<uint64_t> alignment = storeOp.getAlignment();
820 if (alignment > std::numeric_limits<uint32_t>::max()) {
821 return rewriter.notifyMatchFailure(storeOp,
822 "invalid alignment requirement");
823 }
824
825 spirv::StorageClass storageClass = attr.getValue();
826 auto vectorType = storeOp.getVectorType();
827 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
828
829 // For single element vectors, we don't need to bitcast the access chain to
830 // the original vector type. Both is going to be the same, a pointer
831 // to a scalar.
832 Value castedAccessChain =
833 (vectorType.getNumElements() == 1)
834 ? accessChain
835 : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
836 accessChain);
837
838 auto memoryAccess = spirv::MemoryAccess::None;
839 spirv::MemoryAccessAttr memoryAccessAttr;
840 IntegerAttr alignmentAttr;
841 if (alignment.has_value()) {
842 memoryAccess |= spirv::MemoryAccess::Aligned;
843 memoryAccessAttr =
844 spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
845 alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
846 }
847
848 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
849 storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
850 alignmentAttr);
851
852 return success();
853 }
854};
855
856struct VectorReductionToIntDotProd final
857 : OpRewritePattern<vector::ReductionOp> {
858 using Base::Base;
859
860 LogicalResult matchAndRewrite(vector::ReductionOp op,
861 PatternRewriter &rewriter) const override {
862 if (op.getKind() != vector::CombiningKind::ADD)
863 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
864
865 auto resultType = dyn_cast<IntegerType>(op.getType());
866 if (!resultType)
867 return rewriter.notifyMatchFailure(op, "result is not an integer");
868
869 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
870 if (!llvm::is_contained({32, 64}, resultBitwidth))
871 return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
872
873 VectorType inVecTy = op.getSourceVectorType();
874 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
875 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
876 return rewriter.notifyMatchFailure(op, "unsupported vector shape");
877
878 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
879 if (!mul)
880 return rewriter.notifyMatchFailure(
881 op, "reduction operand is not 'arith.muli'");
882
883 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
884 spirv::SDotAccSatOp, false>(op, mul, rewriter)))
885 return success();
886
887 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
888 spirv::UDotAccSatOp, false>(op, mul, rewriter)))
889 return success();
890
891 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
892 spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
893 return success();
894
895 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
896 spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
897 return success();
898
899 return failure();
900 }
901
902private:
903 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
904 typename DotAccOp, bool SwapOperands>
905 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
906 PatternRewriter &rewriter) {
907 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
908 if (!lhs)
909 return failure();
910 Value lhsIn = lhs.getIn();
911 auto lhsInType = cast<VectorType>(lhsIn.getType());
912 if (!lhsInType.getElementType().isInteger(8))
913 return failure();
914
915 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
916 if (!rhs)
917 return failure();
918 Value rhsIn = rhs.getIn();
919 auto rhsInType = cast<VectorType>(rhsIn.getType());
920 if (!rhsInType.getElementType().isInteger(8))
921 return failure();
922
923 if (op.getSourceVectorType().getNumElements() == 3) {
924 IntegerType i8Type = rewriter.getI8Type();
925 auto v4i8Type = VectorType::get({4}, i8Type);
926 Location loc = op.getLoc();
927 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
928 lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
929 ValueRange{lhsIn, zero});
930 rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type,
931 ValueRange{rhsIn, zero});
932 }
933
934 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
935 // we have to swap operands instead in that case.
936 if (SwapOperands)
937 std::swap(lhsIn, rhsIn);
938
939 if (Value acc = op.getAcc()) {
940 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
941 nullptr);
942 } else {
943 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
944 nullptr);
945 }
946
947 return success();
948 }
949};
950
951struct VectorReductionToFPDotProd final
952 : OpConversionPattern<vector::ReductionOp> {
953 using Base::Base;
954
955 LogicalResult
956 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
957 ConversionPatternRewriter &rewriter) const override {
958 if (op.getKind() != vector::CombiningKind::ADD)
959 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
960
961 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
962 if (!resultType)
963 return rewriter.notifyMatchFailure(op, "result is not a float");
964
965 Value vec = adaptor.getVector();
966 Value acc = adaptor.getAcc();
967
968 auto vectorType = dyn_cast<VectorType>(vec.getType());
969 if (!vectorType) {
970 assert(isa<FloatType>(vec.getType()) &&
971 "Expected the vector to be scalarized");
972 if (acc) {
973 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
974 return success();
975 }
976
977 rewriter.replaceOp(op, vec);
978 return success();
979 }
980
981 Location loc = op.getLoc();
982 Value lhs;
983 Value rhs;
984 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
985 lhs = mul.getLhs();
986 rhs = mul.getRhs();
987 } else {
988 // If the operand is not a mul, use a vector of ones for the dot operand
989 // to just sum up all values.
990 lhs = vec;
991 Attribute oneAttr =
992 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
993 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
994 rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
995 }
996 assert(lhs);
997 assert(rhs);
998
999 Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
1000 if (acc)
1001 res = spirv::FAddOp::create(rewriter, loc, acc, res);
1002
1003 rewriter.replaceOp(op, res);
1004 return success();
1005 }
1006};
1007
1008struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
1009 using Base::Base;
1010
1011 LogicalResult
1012 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter) const override {
1014 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1015 Type dstType = typeConverter.convertType(stepOp.getType());
1016 if (!dstType)
1017 return failure();
1018
1019 Location loc = stepOp.getLoc();
1020 int64_t numElements = stepOp.getType().getNumElements();
1021 auto intType =
1022 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
1023
1024 // Input vectors of size 1 are converted to scalars by the type converter.
1025 // We just create a constant in this case.
1026 if (numElements == 1) {
1027 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1028 rewriter.replaceOp(stepOp, zero);
1029 return success();
1030 }
1031
1032 SmallVector<Value> source;
1033 source.reserve(numElements);
1034 for (int64_t i = 0; i < numElements; ++i) {
1035 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1036 Value constOp =
1037 spirv::ConstantOp::create(rewriter, loc, intType, intAttr);
1038 source.push_back(constOp);
1039 }
1040 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1041 source);
1042 return success();
1043 }
1044};
1045
1046struct VectorToElementOpConvert final
1047 : OpConversionPattern<vector::ToElementsOp> {
1048 using Base::Base;
1049
1050 LogicalResult
1051 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1052 ConversionPatternRewriter &rewriter) const override {
1053
1054 SmallVector<Value> results(toElementsOp->getNumResults());
1055 Location loc = toElementsOp.getLoc();
1056
1057 // Input vectors of size 1 are converted to scalars by the type converter.
1058 // We cannot use `spirv::CompositeExtractOp` directly in this case.
1059 // For a scalar source, the result is just the scalar itself.
1060 if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1061 results[0] = adaptor.getSource();
1062 rewriter.replaceOp(toElementsOp, results);
1063 return success();
1064 }
1065
1066 Type srcElementType = toElementsOp.getElements().getType().front();
1067 Type elementType = getTypeConverter()->convertType(srcElementType);
1068 if (!elementType)
1069 return rewriter.notifyMatchFailure(
1070 toElementsOp,
1071 llvm::formatv("failed to convert element type '{0}' to SPIR-V",
1072 srcElementType));
1073
1074 for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1075 // Create an CompositeExtract operation only for results that are not
1076 // dead.
1077 if (element.use_empty())
1078 continue;
1079
1080 Value result = spirv::CompositeExtractOp::create(
1081 rewriter, loc, elementType, adaptor.getSource(),
1082 rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1083 results[idx] = result;
1084 }
1085
1086 rewriter.replaceOp(toElementsOp, results);
1087 return success();
1088 }
1089};
1090
1091} // namespace
// Op lists spliced into the template argument lists of the reduction
// patterns registered in populateVectorToSPIRVPatterns. The order of ops
// within each list is significant: it must match the template parameter
// order of the pattern that consumes it (declared outside this chunk --
// presumably UMax, UMin, SMax, SMin for the integer lists; TODO confirm).

// OpenCL extended-instruction-set integer max/min ops.
#define CL_INT_MAX_MIN_OPS                                                     \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

// GLSL extended-instruction-set integer max/min ops.
#define GL_INT_MAX_MIN_OPS                                                     \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

// OpenCL and GLSL floating-point max/min ops, respectively.
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1100
1102 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1103 patterns.add<
1104 VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
1105 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1106 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1107 VectorToElementOpConvert, VectorInsertOpConvert,
1108 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1109 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1110 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1111 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1112 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1113 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1114 VectorScalarBroadcastPattern, VectorLoadOpConverter,
1115 VectorStoreOpConverter, VectorStepOpConvert>(
1116 typeConverter, patterns.getContext(), PatternBenefit(1));
1117
1118 // Make sure that the more specialized dot product pattern has higher benefit
1119 // than the generic one that extracts all elements.
1120 patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
1121 PatternBenefit(2));
1122}
1123
1125 RewritePatternSet &patterns) {
1126 patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
1127}
return success()
lhs
ArrayAttr()
static constexpr unsigned getNumBits()
#define MINUI(lhs, rhs)
#define INT_CASE(kind, iop)
static uint64_t getFirstIntValue(ArrayAttr attr)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
#define INT_AND_FLOAT_CASE(kind, iop, fop)
#define INT_OR_FLOAT_CASE(kind, fop)
#define mul(a, b)
IntegerType getI8Type()
Definition Builders.cpp:63
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static PointerType get(Type pointeeType, StorageClass storageClass)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populateVectorReductionToSPIRVDotProductPatterns(RewritePatternSet &patterns)
Appends patterns to convert vector reduction of the form:
void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Vector Ops to SPIR-V ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...