VectorEmulateNarrowType.cpp
1//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to emulate
10// narrow types that are not supported by the target hardware, e.g. i4
11// ("emulated type"), using wider types, e.g. i8 ("container type").
12//
13/// Currently, only power-of-two integer types are supported. These are
14/// converted to wider integers that are at least 8 bits wide.
15///
16/// TODO: Support for non-powers-of-two.
17//===----------------------------------------------------------------------===//
18
32#include "mlir/IR/Value.h"
33#include "mlir/Transforms/DialectConversion.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Support/raw_ostream.h"
38#include <cstdint>
39#include <optional>
40
42
43using namespace mlir;
44
45#define DEBUG_TYPE "vector-narrow-type-emulation"
46
47using VectorValue = TypedValue<VectorType>;
48using MemRefValue = TypedValue<BaseMemRefType>;
49
50//===----------------------------------------------------------------------===//
51// Utils
52//===----------------------------------------------------------------------===//
53
54/// Returns a compressed mask for the emulated vector. For example, when
55/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
56/// elements span two dest elements), this method compresses `vector<8xi1>`
57/// into `vector<2xi1>`.
58///
59/// The compressed/output mask value is set iff any mask in the corresponding
60/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
61/// `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
62/// following mask:
63///
64/// %mask = [1, 1, 0, 0, 0, 0]
65///
66/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
67/// will be added in the back to make the number of elements a multiple of
68/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
69///
70/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
71///
72/// then it will return the following new compressed mask:
73///
74/// %mask = [1, 1, 0, 0]
75///
76/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
77/// `numSrcElemsPerDest`.
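///
/// Illustrative sketch for the `vector.create_mask` case (the concrete values
/// below, e.g. an original mask bound of 5 with `numSrcElems` = 8,
/// `numSrcElemsPerDest` = 4 and `numFrontPadElems` = 2, are assumed purely for
/// illustration):
///
///   %mask = vector.create_mask %c5 : vector<8xi1>
///
/// is compressed to (new bound = ceilDiv(5 + 2, 4) = 2, new mask size =
/// ceilDiv(2 + 8, 4) = 3):
///
///   %new_mask = vector.create_mask %c2 : vector<3xi1>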
78static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
79 Location loc, Value mask,
80 int numSrcElems,
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
83
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
86
87 auto numDestElems =
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
89 numSrcElemsPerDest;
90
91 Operation *maskOp = mask.getDefiningOp();
92 SmallVector<vector::ExtractOp, 2> extractOps;
93 // TODO: add support to `vector.broadcast`.
94 // Finding the mask creation operation.
95 while (maskOp &&
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
97 maskOp)) {
98 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
101 }
102 }
103
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
105 maskOp))
106 return failure();
107
108 // Computing the "compressed" mask. All the emulation logic (i.e. computing
109 // new mask index) only happens on the last dimension of the vectors.
110 SmallVector<int64_t> maskShape(
111 cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
112 maskShape.back() = numDestElems;
113 auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
114 std::optional<Operation *> newMask =
115 TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
116 .Case<vector::CreateMaskOp>(
117 [&](auto createMaskOp) -> std::optional<Operation *> {
118 OperandRange maskOperands = createMaskOp.getOperands();
119 // The `vector.create_mask` op creates a mask arrangement
120 // without any zeros at the front. Also, because
121 // `numFrontPadElems` is strictly smaller than
122 // `numSrcElemsPerDest`, the compressed mask generated by
123 // padding the original mask by `numFrontPadElems` will not
124 // have any zeros at the front as well.
125 AffineExpr s0;
126 bindSymbols(rewriter.getContext(), s0);
127 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
128 OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
129 OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
130 rewriter, loc, s0, origIndex);
131 SmallVector<Value> newMaskOperands(maskOperands.drop_back());
132 newMaskOperands.push_back(
133 getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
134 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
135 newMaskOperands);
136 })
137 .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
138 -> std::optional<Operation *> {
139 // Take the shape of mask, compress its trailing dimension:
140 SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
141 int64_t &maskIndex = maskDimSizes.back();
142 maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
143 numSrcElemsPerDest);
144 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
145 maskDimSizes);
146 })
147 .Case<arith::ConstantOp>([&](auto constantOp)
148 -> std::optional<Operation *> {
149 // TODO: Support multiple dimensions.
150 if (maskShape.size() != 1)
151 return std::nullopt;
152 // Rearrange the original mask values to cover the whole potential
153 // loading region. For example, in the case of using byte-size for
154 // emulation, given the following mask:
155 //
156 // %mask = [0, 1, 0, 1, 0, 0]
157 //
158 // With front offset of 1, the mask will be padded with 0s in the front
159 // and back so that:
160 // 1. It is aligned with the effective loading bits
161 // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
162 // coverage size is a multiple of bytes). The new mask will be like
163 // this before compressing:
164 //
165 // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
166 auto originalMask =
167 cast<DenseIntElementsAttr>(constantOp.getValue());
168 SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
169 paddedMaskValues.append(originalMask.template value_begin<bool>(),
170 originalMask.template value_end<bool>());
171 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
172
173 // Compressing by combining every `numSrcElemsPerDest` elements:
174 SmallVector<bool> compressedMaskValues;
175 for (size_t i = 0; i < paddedMaskValues.size();
176 i += numSrcElemsPerDest) {
177 bool combinedValue = false;
178 for (int j = 0; j < numSrcElemsPerDest; ++j) {
179 combinedValue |= paddedMaskValues[i + j];
180 }
181 compressedMaskValues.push_back(combinedValue);
182 }
183 return arith::ConstantOp::create(
184 rewriter, loc,
185 DenseElementsAttr::get(newMaskType, compressedMaskValues));
186 });
187
188 if (!newMask)
189 return failure();
190
191 while (!extractOps.empty()) {
192 newMask =
193 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
194 extractOps.back().getMixedPosition());
195 extractOps.pop_back();
196 }
197
198 return *newMask;
199}
200
201/// Extracts 1-D subvector from a 1-D vector.
202///
203/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204/// from `src`, starting at `offset`. The result is also a rank-1 vector:
205///
206/// vector<numElemsToExtract x !elemType>
207///
208/// (`!elemType` is the element type of the source vector). As `offset` is a known
209/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
210///
211/// EXAMPLE:
212/// %res = vector.extract_strided_slice %src
213/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
214static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
215 Value src, int64_t offset,
216 int64_t numElemsToExtract) {
217 auto vectorType = cast<VectorType>(src.getType());
218 assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
219 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
220 "subvector out of bounds");
221
222 // When extracting all available elements, just use the source vector as the
223 // result.
224 if (vectorType.getNumElements() == numElemsToExtract)
225 return src;
226
227 auto offsets = rewriter.getI64ArrayAttr({offset});
228 auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
229 auto strides = rewriter.getI64ArrayAttr({1});
230
231 auto resultVectorType =
232 VectorType::get({numElemsToExtract}, vectorType.getElementType());
233 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
234 src, offsets, sizes, strides)
235 ->getResult(0);
236}
237
238/// Inserts 1-D subvector into a 1-D vector.
239///
240/// Inserts the input rank-1 source vector into the destination vector starting
241/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
242/// `vector.insert_strided_slice`.
243///
244/// EXAMPLE:
245/// %res = vector.insert_strided_slice %src, %dest
246/// {offsets = [%offset], strides = [1]}
247static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
248 Value src, Value dest, int64_t offset) {
249 [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
250 [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
251 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
252 "expected source and dest to be rank-1 vector types");
253
254 // If overwriting the destination vector, just return the source.
255 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
256 return src;
257
258 auto offsets = rewriter.getI64ArrayAttr({offset});
259 auto strides = rewriter.getI64ArrayAttr({1});
260 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
261 dest, offsets, strides);
262}
263
264/// Extracts 1-D subvector from a 1-D vector.
265///
266/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
267/// from `src`, starting at `offset`. The result is also a rank-1 vector:
268///
269/// vector<numElemsToExtract x !elType>
270///
271/// (`!elType` is the element type of the source vector). As `offset` is assumed
272/// to be a _dynamic_ SSA value, this helper method generates a sequence of
273/// `vector.extract` + `vector.insert` pairs.
274///
275/// EXAMPLE:
276/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
277/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
278/// %c1 = arith.constant 1 : index
279/// %idx2 = arith.addi %offset, %c1 : index
280/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
281/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
282/// (...)
283static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
284 Value src, Value dest,
285 OpFoldResult offset,
286 int64_t numElemsToExtract) {
287 auto srcVecTy = cast<VectorType>(src.getType());
288 assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
289 // NOTE: We are unable to take the offset into account in the following
290 // assert, hence it's still possible that the subvector is out-of-bounds even
291 // if the condition is true.
292 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
293 "subvector out of bounds");
294
295 // When extracting all available elements, just use the source vector as the
296 // result.
297 if (srcVecTy.getNumElements() == numElemsToExtract)
298 return src;
299
300 for (int i = 0; i < numElemsToExtract; ++i) {
301 Value extractLoc =
302 (i == 0) ? dyn_cast<Value>(offset)
303 : arith::AddIOp::create(
304 rewriter, loc, rewriter.getIndexType(),
305 dyn_cast<Value>(offset),
306 arith::ConstantIndexOp::create(rewriter, loc, i));
307 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
308 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
309 }
310 return dest;
311}
312
313/// Inserts 1-D subvector into a 1-D vector.
314///
315/// Inserts the input rank-1 source vector into the destination vector starting
316/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317/// uses a sequence of `vector.extract` + `vector.insert` pairs.
318///
319/// EXAMPLE:
320/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322/// %c1 = arith.constant 1 : index
323/// %idx2 = arith.addi %offset, %c1 : index
324/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326/// (...)
327static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
328 Value src, Value dest,
329 OpFoldResult offset,
330 int64_t numElemsToInsert) {
331 auto srcVecTy = cast<VectorType>(src.getType());
332 auto destVecTy = cast<VectorType>(dest.getType());
333 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334 "expected source and dest to be rank-1 vector types");
335 (void)srcVecTy;
336 (void)destVecTy;
337 assert(numElemsToInsert > 0 &&
338 "the number of elements to insert must be greater than 0");
339 // NOTE: We are unable to take the offset into account in the following
340 // assert, hence it's still possible that the subvector is out-of-bounds even
341 // if the condition is true.
342 assert(numElemsToInsert <= destVecTy.getNumElements() &&
343 "subvector out of bounds");
344
345 Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
346 for (int64_t i = 0; i < numElemsToInsert; ++i) {
347 auto insertLoc =
348 i == 0 ? destOffsetVal
349 : arith::AddIOp::create(
350 rewriter, loc, rewriter.getIndexType(), destOffsetVal,
351 arith::ConstantIndexOp::create(rewriter, loc, i));
352 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
353 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
354 }
355 return dest;
356}
357
358/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
359///
360/// Specifically, use `containerElemTy` for loading a vector of
361/// `emulatedElemTy`. The load location is given by `base` and
362/// `linearizedIndices`, and the load size is given by
363/// `numContainerElemsToLoad`.
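///
/// A sketch of the emitted IR (illustrative only, assuming i4 is emulated via
/// i8 and `numContainerElemsToLoad` = 2; the actual index comes from
/// `linearizedIndices`):
///
///   %load = vector.load %base[%lin_idx] : memref<?xi8>, vector<2xi8>
///   %res  = vector.bitcast %load : vector<2xi8> to vector<4xi4>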
364static Value emulatedVectorLoad(OpBuilder &rewriter, Location loc,
365 Value base,
366 OpFoldResult linearizedIndices,
367 int64_t numContainerElemsToLoad,
368 Type emulatedElemTy,
369 Type containerElemTy) {
370 auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
371 emulatedElemTy.getIntOrFloatBitWidth();
372 auto newLoad = vector::LoadOp::create(
373 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
374 base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
375 return vector::BitCastOp::create(
376 rewriter, loc,
377 VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
378 emulatedElemTy),
379 newLoad);
380}
381
382/// Downcasts two values to `downcastType`, selects values based on `mask`,
383/// and casts the result to `upcastType`.
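///
/// A sketch of the emitted IR (illustrative only, assuming i2 elements in an
/// i8 container, i.e. downcastType = vector<4xi2> and upcastType =
/// vector<1xi8>; the input bitcasts are emitted only when needed):
///
///   %t   = vector.bitcast %trueValue : vector<1xi8> to vector<4xi2>
///   %f   = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>
///   %sel = arith.select %mask, %t, %f : vector<4xi1>, vector<4xi2>
///   %res = vector.bitcast %sel : vector<4xi2> to vector<1xi8>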
384static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
385 VectorType downcastType,
386 VectorType upcastType, Value mask,
387 Value trueValue, Value falseValue) {
388 assert(
389 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
390 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
391 "expected input and output number of bits to match");
392 if (trueValue.getType() != downcastType) {
393 trueValue =
394 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
395 }
396 if (falseValue.getType() != downcastType) {
397 falseValue =
398 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
399 }
400 Value selectedType =
401 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
402 // Upcast the selected value to the new type.
403 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
404}
405
406/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
407/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
408/// subbyte-sized elements spanning 8 bits in total, and the mask is used to select
409/// which elements to store.
410///
411/// Inputs:
412/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
413/// storeIdx = 2
414/// valueToStore = |3|3|3|3| : vector<4xi2>
415/// mask = |0|0|1|1| : vector<4xi1>
416///
417/// Result:
418/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
419static void atomicRMW(OpBuilder &builder, Location loc,
420 MemRefValue linearizedMemref, Value storeIdx,
421 VectorValue valueToStore, Value mask) {
422 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
423
424 // Create an atomic load-modify-write region using
425 // `memref.generic_atomic_rmw`.
426 auto atomicOp = memref::GenericAtomicRMWOp::create(
427 builder, loc, linearizedMemref, ValueRange{storeIdx});
428 Value origValue = atomicOp.getCurrentValue();
429
430 OpBuilder::InsertionGuard guard(builder);
431 builder.setInsertionPointToStart(atomicOp.getBody());
432
433 // Load the original value from memory, and cast it to the original element
434 // type.
435 auto oneElemVecType = VectorType::get({1}, origValue.getType());
436 Value origVecValue = vector::FromElementsOp::create(
437 builder, loc, oneElemVecType, ValueRange{origValue});
438
439 // Construct the final masked value and yield it.
440 Value maskedValue =
441 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
442 oneElemVecType, mask, valueToStore, origVecValue);
443 auto scalarMaskedValue =
444 vector::ExtractOp::create(builder, loc, maskedValue, 0);
445 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
446}
447
448/// Generate a non-atomic read-modify-write sequence for storing to the emulated
449/// type. It has similar logic to `atomicRMW`, but without atomicity.
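///
/// A sketch of the emitted IR (illustrative only, assuming an i8 container
/// and a vector<4xi2> `valueToStore`):
///
///   %old    = vector.load %memref[%idx] : memref<?xi8>, vector<1xi8>
///   %old_bc = vector.bitcast %old : vector<1xi8> to vector<4xi2>
///   %sel    = arith.select %mask, %valueToStore, %old_bc
///               : vector<4xi1>, vector<4xi2>
///   %new    = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///   vector.store %new, %memref[%idx] : memref<?xi8>, vector<1xi8>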
450static void nonAtomicRMW(OpBuilder &builder, Location loc,
451 MemRefValue linearizedMemref, Value linearizedIndex,
452 VectorValue valueToStore, Value mask) {
453 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
454
455 auto oneElemVecType =
456 VectorType::get({1}, linearizedMemref.getType().getElementType());
457 Value origVecValue =
458 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
459 ValueRange{linearizedIndex});
460 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
461 origVecValue);
462
463 Value maskedValue =
464 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
465 oneElemVecType, mask, valueToStore, origVecValue);
466 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
467 linearizedIndex);
468}
469
470/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
471/// and insert it into an empty vector at `insertOffset`.
472/// Inputs:
473/// vec_in = |0|1|2|3| : vector<4xi2>
474/// extractOffset = 1
475/// sliceNumElements = 2
476/// insertOffset = 2
477/// Output:
478/// vec_out = |0|0|1|2| : vector<4xi2>
479static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
480 Location loc, VectorValue vector,
481 int64_t extractOffset,
482 int64_t sliceNumElements,
483 int64_t insertOffset) {
484 assert(vector.getType().getRank() == 1 && "expected 1-D vector");
485 auto vectorElementType = vector.getType().getElementType();
486 // TODO: update and use `alignedConversionPrecondition` in the place of
487 // these asserts.
488 assert(
489 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
490 "sliceNumElements * vector element size must be less than or equal to 8");
491 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
492 "vector element must be a valid sub-byte type");
493 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
494 auto emptyByteVector = arith::ConstantOp::create(
495 rewriter, loc,
496 VectorType::get({emulatedPerContainerElem}, vectorElementType),
497 rewriter.getZeroAttr(
498 VectorType::get({emulatedPerContainerElem}, vectorElementType)));
499 auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
500 extractOffset, sliceNumElements);
501 return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
502 insertOffset);
503}
504
505namespace {
506
507//===----------------------------------------------------------------------===//
508// ConvertVectorStore
509//===----------------------------------------------------------------------===//
510
511// Emulate `vector.store` using a multi-byte container type.
512//
513// The container type is obtained through Op adaptor and would normally be
514// generated via `NarrowTypeEmulationConverter`.
515//
516// EXAMPLE 1
517// (aligned store of i4, emulated using i8 as the container type)
518//
519// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
520//
521// is rewritten as:
522//
523// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
524// vector.store %src_bitcast, %dest_bitcast[%idx]
525// : memref<16xi8>, vector<4xi8>
526//
527// EXAMPLE 2
528// (unaligned store of i2, emulated using i8 as the container type)
529//
530// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
531//
532// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
533// is modelled using 3 bytes:
534//
535// Byte 0 Byte 1 Byte 2
536// +----------+----------+----------+
537// | oooooooo | ooooNNNN | NNoooooo |
538// +----------+----------+----------+
539//
540// N - (N)ew entries (i.e. to be overwritten by vector.store)
541// o - (o)ld entries (to be preserved)
542//
543// For the generated output in the non-atomic case, see:
544// * `@vector_store_i2_const_index_two_partial_stores`
545// in:
546// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
547//
548// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
549// `true` to generate non-atomic RMW sequences.
550struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
551 using Base::Base;
552
553 ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
554 : OpConversionPattern<vector::StoreOp>(context),
555 disableAtomicRMW(disableAtomicRMW) {}
556
557 LogicalResult
558 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560
561 if (op.getValueToStore().getType().getRank() != 1)
562 return rewriter.notifyMatchFailure(op,
563 "only 1-D vectors are supported ATM");
564
565 auto loc = op.getLoc();
566
567 auto valueToStore = cast<VectorValue>(op.getValueToStore());
568 auto containerElemTy =
569 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
570 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
571 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
572 int containerBits = containerElemTy.getIntOrFloatBitWidth();
573
574 // Check per-element alignment.
575 if (containerBits % emulatedBits != 0) {
576 return rewriter.notifyMatchFailure(
577 op, "impossible to pack emulated elements into container elements "
578 "(bit-wise misalignment)");
579 }
580 int emulatedPerContainerElem = containerBits / emulatedBits;
581
582 // Adjust the number of elements to store when emulating narrow types.
583 // Here only the 1-D vector store is considered, and the N-D memref types
584 // should be linearized.
585 // For example, to emulate i4 to i8, the following op:
586 //
587 // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
588 //
589 // can be replaced with
590 //
591 // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
592 // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
593 // vector<4xi8>
594
595 auto origElements = valueToStore.getType().getNumElements();
596 // Note, per-element-alignment was already verified above.
597 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
598 // Does the trailing dim for source and destination match? If yes, then the
599 // corresponding index must be 0.
600 // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
601 // However, that makes some tests fail, so we need to audit first.
602 auto trailingDim = op.getBase().getType().getShape().back();
603 bool trailingDimsMatch =
604 ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
605
606 auto stridedMetadata =
607 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
608
609 // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
610 // non-zero. As such, this is not needed.
611 OpFoldResult linearizedIndices;
612 memref::LinearizedMemRefInfo linearizedInfo;
613 std::tie(linearizedInfo, linearizedIndices) =
614 memref::getLinearizedMemRefOffsetAndSize(
615 rewriter, loc, emulatedBits, containerBits,
616 stridedMetadata.getConstifiedMixedOffset(),
617 stridedMetadata.getConstifiedMixedSizes(),
618 stridedMetadata.getConstifiedMixedStrides(),
619 getAsOpFoldResult(adaptor.getIndices()));
620
621 std::optional<int64_t> foldedNumFrontPadElems =
622 (isDivisibleInSize && trailingDimsMatch)
623 ? 0
624 : getConstantIntValue(linearizedInfo.intraDataOffset);
625
626 if (!foldedNumFrontPadElems) {
627 return rewriter.notifyMatchFailure(
628 op, "subbyte store emulation: dynamic front padding size is "
629 "not yet implemented");
630 }
631
632 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
633
634 // RMWs are not needed when:
635 // * no _partial_ stores are required.
636 // A partial store is defined as a store in which only a part of the
637 // container element is overwritten, e.g.
638 //
639 // Dest before (8 bits)
640 // +----------+
641 // | 11000000 |
642 // +----------+
643 //
644 // Dest after storing 0xF at offset 4 (in bits)
645 // +----------+
646 // | 11001111 |
647 // +----------+
648 //
649 // At a higher level, this translates to:
650 // 1. The source vector size (in bits) is a multiple of byte size.
651 // 2. The address of the store is aligned to the container type width
652 // boundary.
653 //
654 // EXAMPLE 1:
655 // Requires partial store:
656 // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
657 //
658 // EXAMPLE 2:
659 // Does not require a partial store:
660 // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
661 //
662 // TODO: Take linearizedInfo.linearizedOffset into account. This is
663 // currently not needed/used/exercised as all our tests set offset to 0.
664 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
665
666 if (!emulationRequiresPartialStores) {
667 // Basic case: storing full bytes.
668 auto numElements = origElements / emulatedPerContainerElem;
669 auto bitCast = vector::BitCastOp::create(
670 rewriter, loc, VectorType::get(numElements, containerElemTy),
671 op.getValueToStore());
672 rewriter.replaceOpWithNewOp<vector::StoreOp>(
673 op, bitCast.getResult(), memrefBase,
674 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
675 return success();
676 }
677
678 // Next, handle the case when sub-byte read-modify-write
679 // sequences are needed to emulate a vector store.
680 // Here is an example:
681 //
682 // Vector to store: vector<7xi2>
683 // Value to store: 11 11 11 11 11 11 11 (all ones)
684 //
685 // Destination: memref<12xi2>
686 // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
687 //
688 // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
689 //
690 // Destination memref before:
691 //
692 // Byte 0 Byte 1 Byte 2
693 // +----------+----------+----------+
694 // | 00000000 | 00000000 | 00000000 |
695 // +----------+----------+----------+
696 //
697 // Destination memref after:
698 //
699 // Byte 0 Byte 1 Byte 2
700 // +----------+----------+----------+
701 // | 00001111 | 11111111 | 11000000 |
702 // +----------+----------+----------+
703 //
704 // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
705 // need for atomicity). Stores to Bytes 0 and Byte 2 are "partial", hence
706 // requiring RMW access (atomicity is required).
707
708 // The index into the target memref we are storing to.
709 Value currentDestIndex =
710 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
711 // The index into the source vector we are currently processing.
712 auto currentSourceIndex = 0;
713
714 // Build a mask used for rmw.
715 auto subWidthStoreMaskType =
716 VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
717
718 auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
719
720 // 1. Partial width store for the leading byte.
721 // When the store address is not aligned to the emulated width boundary,
722 // deal with the unaligned part so that the remaining elements are aligned
723 // to the width boundary.
724 auto frontSubWidthStoreElem =
725 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
726 emulatedPerContainerElem;
727 if (frontSubWidthStoreElem > 0) {
728 SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
729 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
730 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
731 origElements, true);
732 frontSubWidthStoreElem = origElements;
733 } else {
734 std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
735 *foldedNumFrontPadElems, true);
736 }
737 auto frontMask = arith::ConstantOp::create(
738 rewriter, loc,
739 DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
740
741 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
742 auto value =
743 extractSliceIntoByte(rewriter, loc, valueToStore, 0,
744 frontSubWidthStoreElem, *foldedNumFrontPadElems);
745
746 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
747 cast<VectorValue>(value), frontMask.getResult());
748 }
749
750 if (currentSourceIndex >= origElements) {
751 rewriter.eraseOp(op);
752 return success();
753 }
754
755 // Increment the destination index by 1 to align to the emulated width
756 // boundary.
757 auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
758 currentDestIndex = arith::AddIOp::create(
759 rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
760
761 // 2. Full width store for the inner output bytes.
762 // After the previous step, the store address is aligned to the emulated
763 // width boundary.
764 int64_t fullWidthStoreSize =
765 (origElements - currentSourceIndex) / emulatedPerContainerElem;
766 int64_t numNonFullWidthElements =
767 fullWidthStoreSize * emulatedPerContainerElem;
768 if (fullWidthStoreSize > 0) {
769 auto fullWidthStorePart = staticallyExtractSubvector(
770 rewriter, loc, valueToStore, currentSourceIndex,
771 numNonFullWidthElements);
772
773 auto originType = cast<VectorType>(fullWidthStorePart.getType());
774 auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
775 auto storeType = VectorType::get(
776 {originType.getNumElements() / emulatedPerContainerElem},
777 memrefElemType);
778 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
779 fullWidthStorePart);
780 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
781 currentDestIndex);
782
783 currentSourceIndex += numNonFullWidthElements;
784 currentDestIndex = arith::AddIOp::create(
785 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
786 arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
787 }
788
789 // 3. Partial width store for the trailing output byte.
790 // It is needed when the residual length is smaller than the emulated width,
791 // which is not covered in step 2 above.
792 auto remainingElements = origElements - currentSourceIndex;
793 if (remainingElements != 0) {
794 auto subWidthStorePart =
795 extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
796 currentSourceIndex, remainingElements, 0);
797
798 // Generate back mask.
799 auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
800 std::fill_n(maskValues.begin(), remainingElements, 1);
801 auto backMask = arith::ConstantOp::create(
802 rewriter, loc,
803 DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
804
805 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
806 cast<VectorValue>(subWidthStorePart), backMask.getResult());
807 }
808
809 rewriter.eraseOp(op);
810 return success();
811 }
812
813private:
814 const bool disableAtomicRMW;
815};
816
817//===----------------------------------------------------------------------===//
818// ConvertVectorMaskedStore
819//===----------------------------------------------------------------------===//
820
821/// Converts `vector.maskedstore` operations on narrow element types to work
822/// with wider, byte-aligned container types by adjusting the mask and using
823/// bitcasting.
824///
825/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
826/// (each `i8` container element holds two `i4` values) and storing with an
827/// adjusted mask.
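///
/// A sketch of the rewrite (illustrative only; the memref shapes are assumed
/// here, and the value-level corner cases are described in detail inside
/// `matchAndRewrite`):
///
///   vector.maskedstore %base[%idx], %mask, %value
///       : memref<12xi4>, vector<6xi1>, vector<6xi4>
///
/// roughly becomes:
///
///   %new_mask = <compressed mask>                           : vector<3xi1>
///   %old      = vector.maskedload %base_i8[%lin_idx], %new_mask, %zeros
///       : memref<6xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %old_bc   = vector.bitcast %old : vector<3xi8> to vector<6xi4>
///   %sel      = arith.select %mask, %value, %old_bc : vector<6xi1>, vector<6xi4>
///   %new      = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
///   vector.maskedstore %base_i8[%lin_idx], %new_mask, %new
///       : memref<6xi8>, vector<3xi1>, vector<3xi8>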
828struct ConvertVectorMaskedStore final
829 : OpConversionPattern<vector::MaskedStoreOp> {
830 using Base::Base;
831
832 LogicalResult
833 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
834 ConversionPatternRewriter &rewriter) const override {
835
836 // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
837 if (op.getValueToStore().getType().getRank() != 1)
838 return rewriter.notifyMatchFailure(
839 op, "Memref in vector.maskedstore op must be flattened beforehand.");
840
841 auto loc = op.getLoc();
842 auto containerElemTy =
843 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
844 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
845 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
846 int containerBits = containerElemTy.getIntOrFloatBitWidth();
847
848 // Check per-element alignment.
849 if (containerBits % emulatedBits != 0) {
850 return rewriter.notifyMatchFailure(
851 op, "impossible to pack emulated elements into container elements "
852 "(bit-wise misalignment)");
853 }
854
855 int emulatedPerContainerElem = containerBits / emulatedBits;
856 int origElements = op.getValueToStore().getType().getNumElements();
857 if (origElements % emulatedPerContainerElem != 0)
858 return failure();
859
860 auto stridedMetadata =
861 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
862 OpFoldResult linearizedIndicesOfr;
863 memref::LinearizedMemRefInfo linearizedInfo;
864 std::tie(linearizedInfo, linearizedIndicesOfr) =
865 memref::getLinearizedMemRefOffsetAndSize(
866 rewriter, loc, emulatedBits, containerBits,
867 stridedMetadata.getConstifiedMixedOffset(),
868 stridedMetadata.getConstifiedMixedSizes(),
869 stridedMetadata.getConstifiedMixedStrides(),
870 getAsOpFoldResult(adaptor.getIndices()));
871 Value linearizedIndices =
872 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
873
874 // Load the whole data and use arith.select to handle the corner cases.
875 //
876 // As an example, for this masked store of i4 values:
877 //
878 // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
879 //
880 // and given these input values:
881 //
882 // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
883 // %0[%c0, %c0] =
884 // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
885 // %val_to_store =
886 // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
887 //
888 // we'll have the following i4 output:
889 //
890 // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
891 //
892 // Emulating the above using i8 will give:
893 //
894 // %compressed_mask = [1, 1, 1, 0] (4 * i1)
895 // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
896 // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
897 // %select_using_shifted_mask =
898 // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
899 // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
900 //
901 // Using the compressed mask to store %packed_data results in expected
902 // output.
903 //
904 // FIXME: Make an example based on the comment above work (see #115460 for
905 // reproducer).
906 FailureOr<Operation *> newMask = getCompressedMaskOp(
907 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
908 if (failed(newMask))
909 return failure();
910
911 auto numElements = (origElements + emulatedPerContainerElem - 1) /
912 emulatedPerContainerElem;
913 auto newType = VectorType::get(numElements, containerElemTy);
914 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
915 rewriter.getZeroAttr(newType));
916
917 auto newLoad = vector::MaskedLoadOp::create(
918 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
919 newMask.value()->getResult(0), passThru);
920
921 auto newBitCastType =
922 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
923 Value valueToStore =
924 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
925 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
926 op.getValueToStore(), valueToStore);
927 valueToStore =
928 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
929
930 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
931 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
932 valueToStore);
933 return success();
934 }
935};
936
937//===----------------------------------------------------------------------===//
938// ConvertVectorLoad
939//===----------------------------------------------------------------------===//
940
941/// Converts `vector.load` on narrow element types to work with
942/// wider, byte-aligned container types by adjusting load sizes and using
943/// bitcasting.
944///
945/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
946/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
947/// container holds two `i4` values) and bitcasting back.
948///
949/// There are cases where the number of elements to load is not byte-aligned. In
950/// those cases, loads are converted to byte-aligned, byte-sized loads and the
951/// target vector is extracted from the loaded vector.
952struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
953 using Base::Base;
954
955 LogicalResult
956 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
957 ConversionPatternRewriter &rewriter) const override {
958 // Prerequisite: memref in the vector.load op is flattened into 1-D.
959 if (op.getVectorType().getRank() != 1)
960 return rewriter.notifyMatchFailure(
961 op, "Memref in emulated vector ops must be flattened beforehand.");
962
963 auto loc = op.getLoc();
964 auto containerElemTy =
965 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
966 Type emulatedElemTy = op.getType().getElementType();
967 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
968 int containerBits = containerElemTy.getIntOrFloatBitWidth();
969
970 // Check per-element alignment.
971 if (containerBits % emulatedBits != 0) {
972 return rewriter.notifyMatchFailure(
973 op, "impossible to pack emulated elements into container elements "
974 "(bit-wise misalignment)");
975 }
976 int emulatedPerContainerElem = containerBits / emulatedBits;
977
978 // Adjust the number of elements to load when emulating narrow types,
979 // and then cast back to the original type with vector.bitcast op.
980 // For example, to emulate i4 to i8, the following op:
981 //
982 // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
983 //
984 // can be replaced with
985 //
986 // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
987 // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
988 //
989 // There are cases where the number of elements to load is not byte-aligned,
990 // for example:
991 //
992 // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
993 //
994 // we will have to load extra bytes and extract the exact slice in between.
995 //
996 // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
997 // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
998 // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
999 // = [1]}
1000 // : vector<8xi2> to vector<3xi2>
1001 //
1002 // TODO: Currently the extract_strided_slice's attributes must be known at
1003 // compile time as they must be constants.
1004
1005 auto origElements = op.getVectorType().getNumElements();
1006 // Note, per-element-alignment was already verified above.
1007 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1008
1009 auto stridedMetadata =
1010 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1011
1012 OpFoldResult linearizedIndices;
1013 memref::LinearizedMemRefInfo linearizedInfo;
1014 std::tie(linearizedInfo, linearizedIndices) =
1015 memref::getLinearizedMemRefOffsetAndSize(
1016 rewriter, loc, emulatedBits, containerBits,
1017 stridedMetadata.getConstifiedMixedOffset(),
1018 stridedMetadata.getConstifiedMixedSizes(),
1019 stridedMetadata.getConstifiedMixedStrides(),
1020 getAsOpFoldResult(adaptor.getIndices()));
1021
1022 std::optional<int64_t> foldedIntraVectorOffset =
1023 isDivisibleInSize ? 0
1024 : getConstantIntValue(linearizedInfo.intraDataOffset);
1025
1026 // Always load enough elements to cover the original elements.
1027 int64_t maxintraDataOffset =
1028 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1029 auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1030 emulatedPerContainerElem);
1031 Value result =
1032 emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1033 numElements, emulatedElemTy, containerElemTy);
1034
1035 if (!foldedIntraVectorOffset) {
1036 auto resultVector = arith::ConstantOp::create(
1037 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1038 result = dynamicallyExtractSubVector(
1039 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1040 linearizedInfo.intraDataOffset, origElements);
1041 } else if (!isDivisibleInSize) {
1042 result = staticallyExtractSubvector(
1043 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1044 }
1045 rewriter.replaceOp(op, result);
1046 return success();
1047 }
1048};
1049
1050//===----------------------------------------------------------------------===//
1051// ConvertVectorMaskedLoad
1052//===----------------------------------------------------------------------===//
1053
1054/// Converts `vector.maskedload` operations on narrow element types to work with
1055/// wider, byte-aligned container types by adjusting the mask and using
1056/// bitcasting.
1057///
1058/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1059/// bitcasting, since each `i8` container element holds two `i4` values.
1060struct ConvertVectorMaskedLoad final
1061 : OpConversionPattern<vector::MaskedLoadOp> {
1062 using Base::Base;
1063
1064 LogicalResult
1065 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1066 ConversionPatternRewriter &rewriter) const override {
1067 if (op.getVectorType().getRank() != 1)
1068 return rewriter.notifyMatchFailure(
1069 op, "Memref in emulated vector ops must be flattened beforehand.");
1070
1071 auto loc = op.getLoc();
1072
1073 auto containerElemTy =
1074 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1075 Type emulatedElemTy = op.getType().getElementType();
1076 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1077 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1078
1079 // Check per-element alignment.
1080 if (containerBits % emulatedBits != 0) {
1081 return rewriter.notifyMatchFailure(
1082 op, "impossible to pack emulated elements into container elements "
1083 "(bit-wise misalignment)");
1084 }
1085 int emulatedPerContainerElem = containerBits / emulatedBits;
1086
1087 // Adjust the number of elements to load when emulating narrow types,
1088 // and then cast back to the original type with vector.bitcast op.
1089 // For example, to emulate i4 to i8, the following op:
1090 //
1091 // %mask = vector.constant_mask [3] : vector<6xi1>
1092 // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1093 // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1094 //
1095 // can be replaced with
1096 //
1097 // %new_mask = vector.constant_mask [2] : vector<3xi1>
1098 // %new_pass_thru = vector.bitcast %pass_thru :
1099 // vector<6xi4> to vector<3xi8>
1100 // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1101 // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1102 // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1103 //
1104 // Since we are effectively loading 16 bits (2xi8) from the memref with the
1105 // new mask, while originally we only wanted to effectively load 12 bits
1106 // (3xi4) from the memref, we need to set the second half of the last i8
1107 // that was effectively loaded (i.e. the second i8) to %pass_thru.
1108 //
1109 // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1110 //
1111 // Given these input values:
1112 // %mask = [1, 1, 1, 0, 0, 0]
1113 // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1114 // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1115 //
1116 // we'll have:
1117 //
1118 // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1119 //
1120 // %new_mask = [1, 1, 0]
1121 // %new_pass_thru = [0x78, 0x9A, 0xBC]
1122 // %1 = [0x12, 0x34, 0xBC]
1123 // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1124 // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1125 //
1126 // TODO: Currently, only the even number of elements loading is supported.
1127 // To deal with the odd number of elements, one has to extract the
1128 // subvector at the proper offset after bit-casting.
1129 auto origType = op.getVectorType();
1130 auto origElements = origType.getNumElements();
1131 // Note, per-element-alignment was already verified above.
1132 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1133
1134 auto stridedMetadata =
1135 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1136 OpFoldResult linearizedIndices;
1137 memref::LinearizedMemRefInfo linearizedInfo;
1138 std::tie(linearizedInfo, linearizedIndices) =
1139 memref::getLinearizedMemRefOffsetAndSize(
1140 rewriter, loc, emulatedBits, containerBits,
1141 stridedMetadata.getConstifiedMixedOffset(),
1142 stridedMetadata.getConstifiedMixedSizes(),
1143 stridedMetadata.getConstifiedMixedStrides(),
1144 getAsOpFoldResult(adaptor.getIndices()));
1145
1146 std::optional<int64_t> foldedIntraVectorOffset =
1147 isDivisibleInSize ? 0
1148 : getConstantIntValue(linearizedInfo.intraDataOffset);
1149
1150 int64_t maxIntraDataOffset =
1151 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1152 FailureOr<Operation *> newMask =
1153 getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1154 emulatedPerContainerElem, maxIntraDataOffset);
1155 if (failed(newMask))
1156 return failure();
1157
1158 Value passthru = op.getPassThru();
1159
1160 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1161 emulatedPerContainerElem);
1162 auto loadType = VectorType::get(numElements, containerElemTy);
1163 auto newBitcastType =
1164 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1165
1166 auto emptyVector = arith::ConstantOp::create(
1167 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1168 if (!foldedIntraVectorOffset) {
1169 passthru = dynamicallyInsertSubVector(
1170 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1171 origElements);
1172 } else if (!isDivisibleInSize) {
1173 passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1174 *foldedIntraVectorOffset);
1175 }
1176 auto newPassThru =
1177 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1178
1179 // Generating the new masked load.
1180 auto newLoad = vector::MaskedLoadOp::create(
1181 rewriter, loc, loadType, adaptor.getBase(),
1182 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1183 newMask.value()->getResult(0), newPassThru);
1184
1185 // Setting the part that originally was not effectively loaded from memory
1186 // to pass through.
1187 auto bitCast =
1188 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1189
1190 Value mask = op.getMask();
1191 auto newSelectMaskType = VectorType::get(
1192 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1193 // TODO: try to fold if op's mask is constant
1194 auto emptyMask =
1195 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1196 rewriter.getZeroAttr(newSelectMaskType));
1197 if (!foldedIntraVectorOffset) {
1198 mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1199 linearizedInfo.intraDataOffset,
1200 origElements);
1201 } else if (!isDivisibleInSize) {
1202 mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1203 *foldedIntraVectorOffset);
1204 }
1205
1206 Value result =
1207 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1208 if (!foldedIntraVectorOffset) {
1209 result = dynamicallyExtractSubVector(
1210 rewriter, loc, result, op.getPassThru(),
1211 linearizedInfo.intraDataOffset, origElements);
1212 } else if (!isDivisibleInSize) {
1213 result = staticallyExtractSubvector(
1214 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1215 }
1216 rewriter.replaceOp(op, result);
1217
1218 return success();
1219 }
1220};
1221
1222/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1223///
1224/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1225/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1226/// (a multi-byte scalar, e.g. i16), where N is some integer.
1227///
1228/// Put differently, this method checks whether this would be valid:
1229///
1230/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1231///
1232/// EXAMPLES:
1233/// * vector<4xi4> -> i16 - yes (N = 1)
1234/// * vector<4xi4> -> i8 - yes (N = 2)
1235/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1236/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1237static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1238 Type multiByteScalarTy) {
1239 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1240
1241 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1242 int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1243
1244 assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1245 assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1246 assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1247
1248 int elemsPerMultiByte = multiByteBits / subByteBits;
1249
1250 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// ConvertVectorTransferRead
1255//===----------------------------------------------------------------------===//
1256
1257// TODO: Document-me
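//
// A sketch of the conversion (illustrative only; the linearized index, the
// i8 memref shape and the handling of non-byte-aligned sizes are determined
// by the code below):
//
//   %0 = vector.transfer_read %src[%idx], %pad : memref<?xi4>, vector<8xi4>
//
// roughly becomes:
//
//   %pad_ext = arith.extui %pad : i4 to i8
//   %1 = vector.transfer_read %src_i8[%lin_idx], %pad_ext
//            : memref<?xi8>, vector<4xi8>
//   %2 = vector.bitcast %1 : vector<4xi8> to vector<8xi4>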
1258struct ConvertVectorTransferRead final
1259 : OpConversionPattern<vector::TransferReadOp> {
1260 using Base::Base;
1261
1262 LogicalResult
1263 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1264 ConversionPatternRewriter &rewriter) const override {
1265
1266 // Prerequisites: memref in the vector.transfer_read op is flattened into
1267 // 1-D.
1268 if (op.getVectorType().getRank() != 1)
1269 return rewriter.notifyMatchFailure(
1270 op, "Memref in emulated vector ops must be flattened beforehand.");
1271
1272 auto loc = op.getLoc();
1273 auto containerElemTy =
1274 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1275 Type emulatedElemTy = op.getType().getElementType();
1276 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1277 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1278
1279 // Check per-element alignment.
1280 if (containerBits % emulatedBits != 0) {
1281 return rewriter.notifyMatchFailure(
1282 op, "impossible to pack emulated elements into container elements "
1283 "(bit-wise misalignment)");
1284 }
1285 int emulatedPerContainerElem = containerBits / emulatedBits;
1286
1287 auto origElements = op.getVectorType().getNumElements();
1288
1289 // Note, per-element-alignment was already verified above.
1290 bool isDivisibleInSize =
1291 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1292
1293 // Pad the padding value with 0s on the left. These bits are discarded and
1294 // thus their values don't matter.
1295 Value padding = adaptor.getPadding();
1296 if (!padding.getType().isInteger()) {
1297 padding = arith::BitcastOp::create(
1298 rewriter, loc,
1299 IntegerType::get(rewriter.getContext(),
1300 padding.getType().getIntOrFloatBitWidth()),
1301 padding);
1302 }
1303 auto newPadding =
1304 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1305
1306 auto stridedMetadata =
1307 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1308
1309 OpFoldResult linearizedIndices;
1310 memref::LinearizedMemRefInfo linearizedInfo;
1311 std::tie(linearizedInfo, linearizedIndices) =
1312 memref::getLinearizedMemRefOffsetAndSize(
1313 rewriter, loc, emulatedBits, containerBits,
1314 stridedMetadata.getConstifiedMixedOffset(),
1315 stridedMetadata.getConstifiedMixedSizes(),
1316 stridedMetadata.getConstifiedMixedStrides(),
1317 getAsOpFoldResult(adaptor.getIndices()));
1318
1319 std::optional<int64_t> foldedIntraVectorOffset =
1320 isDivisibleInSize ? 0
1321 : getConstantIntValue(linearizedInfo.intraDataOffset);
1322
1323 int64_t maxIntraDataOffset =
1324 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1325 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1326 emulatedPerContainerElem);
1327
1328 auto newRead = vector::TransferReadOp::create(
1329 rewriter, loc, VectorType::get(numElements, containerElemTy),
1330 adaptor.getBase(),
1331 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1332 newPadding);
1333
1334 auto bitCast = vector::BitCastOp::create(
1335 rewriter, loc,
1336 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1337 newRead);
1338
1339 Value result = bitCast->getResult(0);
1340 if (!foldedIntraVectorOffset) {
1341 auto zeros = arith::ConstantOp::create(
1342 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1343 result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1344 linearizedInfo.intraDataOffset,
1345 origElements);
1346 } else if (!isDivisibleInSize) {
1347 result = staticallyExtractSubvector(
1348 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1349 }
1350 rewriter.replaceOp(op, result);
1351
1352 return success();
1353 }
1354};
1355} // end anonymous namespace
1356
1357//===----------------------------------------------------------------------===//
1358// RewriteBitCastOfTruncI
1359//===----------------------------------------------------------------------===//
1360
1361namespace {
1362
1363/// Helper struct to keep track of the provenance of a contiguous set of bits
1364/// in a source vector.
1365struct SourceElementRange {
1366 /// The index of the source vector element that contributes bits to *this.
1367 int64_t sourceElementIdx;
1368 /// The range of bits in the source vector element that contribute to *this.
1369 int64_t sourceBitBegin;
1370 int64_t sourceBitEnd;
1371};
1372
1373struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1374 /// Given the index of a SourceElementRange in the SourceElementRangeList,
1375 /// compute the amount of bits that need to be shifted to the left to get the
1376 /// bits in their final location. This shift amount is simply the sum of the
1377 /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1378/// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
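  ///
  /// For example (taking the second row of the table in the rewrite pattern
  /// documentation further below, i.e. the list {1: b@[3..5)}, {2: b@[0..5)},
  /// {3: b@[0..1)}), the shift amounts are 0, (5 - 3) = 2, and
  /// (5 - 3) + (5 - 0) = 7 for shuffleIdx = 0, 1 and 2, respectively.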
1379 int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1380 int64_t res = 0;
1381 for (int64_t i = 0; i < shuffleIdx; ++i)
1382 res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1383 return res;
1384 }
1385};
1386
1387/// Helper struct to enumerate the source elements and bit ranges that are
1388/// involved in a bitcast operation.
1389/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1390/// any 1-D vector shape and any source/target bitwidths.
1391/// This creates and holds a mapping of the form:
1392/// [dstVectorElementJ] ==
1393/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1394/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1395/// [0] = {0, [0-8)}
1396/// [1] = {0, [8-16)}
1397/// [2] = {0, [16-24)}
1398/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1399/// [0] = {0, [0, 10)}
1400/// [1] = {0, [10, 15)}, {1, [0, 5)} and [2] = {1, [5, 15)}
1401struct BitCastBitsEnumerator {
1402 BitCastBitsEnumerator(VectorType sourceVectorType,
1403 VectorType targetVectorType);
1404
1405 int64_t getMaxNumberOfEntries() {
1406 int64_t numVectors = 0;
1407 for (const auto &l : sourceElementRanges)
1408 numVectors = std::max(numVectors, (int64_t)l.size());
1409 return numVectors;
1410 }
1411
1412 VectorType sourceVectorType;
1413 VectorType targetVectorType;
1414 SmallVector<SourceElementRangeList> sourceElementRanges;
1415};
1416
1417/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1418/// advantage of high-level information to avoid leaving LLVM to scramble with
1419/// peephole optimizations.
1420/// BitCastBitsEnumerator encodes for each element of the target vector the
1421/// provenance of the bits in the source vector. We can "transpose" this
1422/// information to build a sequence of shuffles and bitwise ops that will
1423/// produce the desired result.
1424//
1425/// Consider the following motivating example:
1426/// ```
1427/// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1428/// ```
1429//
1430/// BitCastBitsEnumerator contains the following information:
1431/// ```
1432/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1433/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1434/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1435/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1436/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1437/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1438/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1439/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1440/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1441/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1442/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1443/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1444/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1445/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1446/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1447/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1448/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1449/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1450/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1451/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1452/// ```
1453///
1454/// In the above, each row represents one target vector element and each
1455/// column represents one bit contribution from a source vector element.
1456/// The algorithm creates vector.shuffle operations (in this case there are 3
1457/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1458/// algorithm populates the bits as follows:
1459/// ```
1460/// src bits 0 ...
1461/// 1st shuffle |xxxxx |xx |...
1462/// 2nd shuffle | xxx| xxxxx |...
1463/// 3rd shuffle | | x|...
1464/// ```
1465//
1466/// The algorithm proceeds as follows:
1467/// 1. for each vector.shuffle, collect the source vectors that participate in
1468/// this shuffle. One source vector per target element of the resulting
1469/// vector.shuffle. If there is no source element contributing bits for the
1470/// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1471/// 2 columns).
1472/// 2. represent the bitrange in the source vector as a mask. If there is no
1473/// source element contributing bits for the current vector.shuffle, take 0.
1474/// 3. shift right by the proper amount to align the source bitrange at
1475/// position 0. This is exactly the low end of the bitrange. For instance,
1476/// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1477/// shift right by 3 to get the bits contributed by the source element #1
1478/// into position 0.
1479/// 4. shift left by the proper amount to align to the desired position in
1480/// the result element vector. For instance, the contribution of the second
1481/// source element for the first row needs to be shifted by `5` to form the
1482/// first i8 result element.
1483///
1484/// Eventually, we end up building the sequence
1485/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1486/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1487/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1488struct BitCastRewriter {
1489 /// Helper metadata struct to hold the static quantities for the rewrite.
1490 struct Metadata {
1491 SmallVector<int64_t> shuffles;
1492 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1493 };
1494
1495 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1496
1497 /// Verify that general preconditions for the rewrite are met.
1498 LogicalResult commonPrecondition(PatternRewriter &rewriter,
1499 VectorType preconditionType, Operation *op);
1500
1501 /// Precompute the metadata for the rewrite.
1502 SmallVector<BitCastRewriter::Metadata>
1503 precomputeMetadata(IntegerType shuffledElementType);
1504
1505 /// Rewrite one step of the sequence:
1506 /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1507 Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1508 Value initialValue, Value runningResult,
1509 const BitCastRewriter::Metadata &metadata);
1510
1511private:
1512 /// Underlying enumerator that encodes the provenance of the bits in each
1513 /// element of the result vector.
1514 BitCastBitsEnumerator enumerator;
1515};
1516
1517} // namespace
1518
1519[[maybe_unused]] static raw_ostream &
1520operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1521 for (const auto &l : vec) {
1522 for (auto it : llvm::enumerate(l)) {
1523 os << "{ " << it.value().sourceElementIdx << ": b@["
1524 << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1525 << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1526 }
1527 os << "\n";
1528 }
1529 return os;
1530}
1531
1532BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1533 VectorType targetVectorType)
1534 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1535
1536 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1537 "requires -D non-scalable vector type");
1538 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1539 "requires -D non-scalable vector type");
1540 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1541 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1542 LDBG() << "sourceVectorType: " << sourceVectorType;
1543
1544 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1545 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1546 LDBG() << "targetVectorType: " << targetVectorType;
1547
1548 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1549 (void)mostMinorSourceDim;
1550 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1551 "source and target bitwidths must match");
1552
1553 // Prepopulate one source element range per target element.
1554 sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1555 for (int64_t resultBit = 0; resultBit < bitwidth;) {
1556 int64_t resultElement = resultBit / targetBitWidth;
1557 int64_t resultBitInElement = resultBit % targetBitWidth;
1558 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1559 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1560 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1561 targetBitWidth - resultBitInElement);
1562 sourceElementRanges[resultElement].push_back(
1563 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1564 resultBit += step;
1565 }
1566}
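The same bit-walking loop can be cross-checked standalone. The sketch below (plain C++, hypothetical local names) recomputes the `vector<2xi15> -> vector<3xi10>` decomposition from the BitCastBitsEnumerator comment above:

```cpp
// Recompute the [dstElement] -> {srcElement, bitRange} table for the
// vector<2xi15> -> vector<3xi10> example, using the same walk as the
// constructor above.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t sourceBitWidth = 15, targetBitWidth = 10, numTargetElems = 3;
  const int64_t bitwidth = targetBitWidth * numTargetElems; // 30 bits total
  std::vector<std::vector<std::array<int64_t, 3>>> ranges(numTargetElems);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    ranges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
  // Prints: [0] = {0, [0, 10)}
  //         [1] = {0, [10, 15)}, {1, [0, 5)}
  //         [2] = {1, [5, 15)}
  for (int64_t j = 0; j < numTargetElems; ++j) {
    std::printf("[%lld] =", static_cast<long long>(j));
    for (auto &r : ranges[j])
      std::printf(" {%lld, [%lld, %lld)}", static_cast<long long>(r[0]),
                  static_cast<long long>(r[1]), static_cast<long long>(r[2]));
    std::printf("\n");
  }
  return 0;
}
```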
1567
1568BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1569 VectorType targetVectorType)
1570 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1571 LDBG() << "\n" << enumerator.sourceElementRanges;
1572}
1573
1574/// Verify that the precondition type meets the common preconditions for any
1575/// conversion.
1576static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1577 VectorType preconditionType,
1578 Operation *op) {
1579 if (!preconditionType || preconditionType.isScalable())
1580 return rewriter.notifyMatchFailure(op, "scalable vector");
1581
1582 // TODO: consider relaxing this restriction in the future if we find ways
1583 // to really work with subbyte elements across the MLIR/LLVM boundary.
1584 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1585 if (bitwidth % 8 != 0)
1586 return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1587
1588 return success();
1589}
1590
1591LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1592 VectorType preconditionType,
1593 Operation *op) {
1594 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1595 return rewriter.notifyMatchFailure(op, "types are not vector");
1596
1597 if (!preconditionType || preconditionType.getRank() != 1)
1598 return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1599
1600 return commonConversionPrecondition(rewriter, preconditionType, op);
1601}
1602
1603/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1604///
1605/// Alignment means that `subByteVecTy` can be packed into a vector of
1606/// `containerTy` elements. More specifically:
1607/// 1. The bit-width of `containerTy` is a multiple of the
1608/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1609/// this multiple is 4.
1610/// 2. The multiple from 1. above evenly divides the number of (trailing)
1611/// elements in `subByteVecTy`.
1612///
1613/// EXAMPLE 1:
1614/// `subByteVecTy = vector<2xi4>`, and
1615/// `containerTy = i16`
1616///
1617/// 2 evenly divides 4 (= 16 / 4), hence both conditions are _met_.
1618///
1619/// EXAMPLE 2:
1620/// `subByteVecTy = vector<3xi4>`, and
1621/// `containerTy = i16`
1622///
1623/// 3 _does not_ evenly divide 4 (= 16 / 4), hence the conditions are _not met_.
1624///
1625/// EXAMPLE 3:
1626/// `subByteVecTy = vector<3xi3>`, and
1627/// `containerTy = i16`
1628///
1629/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1630///
1631/// NOTE: This method assumes that common conversion preconditions are met. In
1632/// particular, `containerTy` is assumed to be a
1633/// multi-byte scalar type (e.g., i8, i16, i32).
1634static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1635 VectorType subByteVecTy,
1636 Type containerTy,
1637 Operation *op) {
1638 assert(containerTy.isIntOrFloat() &&
1639 "container element type is not a scalar");
1640
1641 // TODO: This is validating the inputs rather than checking the conditions
1642 // documented above. Replace with an assert.
1643 if (!subByteVecTy)
1644 return rewriter.notifyMatchFailure(op, "not a vector!");
1645
1646 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1647 unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1648
1649 // Enforced by the common pre-conditions.
1650 assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1651
1652 // TODO: Add support for other widths (when/if needed)
1653 if (subByteBits != 2 && subByteBits != 4)
1654 return rewriter.notifyMatchFailure(
1655 op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1656
1657 // Condition 1 ("per-element" alignment)
1658 if (containerBits % subByteBits != 0)
1659 return rewriter.notifyMatchFailure(op, "unaligned element types");
1660
1661 // Condition 2 ("full" alignment)
1662 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1663 return rewriter.notifyMatchFailure(
1664 op, "not possible to fit this sub-byte vector type into a vector of "
1665 "the given multi-byte type");
1666
1667 return success();
1668}
1669
1670SmallVector<BitCastRewriter::Metadata>
1671BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1672 SmallVector<BitCastRewriter::Metadata> result;
1673 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1674 shuffleIdx < e; ++shuffleIdx) {
1675 SmallVector<int64_t> shuffles;
1676 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1677
1678 // Create the attribute quantities for the shuffle / mask / shift ops.
1679 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1680 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1681 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1682 : 0;
1683 shuffles.push_back(sourceElement);
1684
1685 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1686 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1687 : 0;
1688 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1689 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1690 : 0;
1691 IntegerAttr mask = IntegerAttr::get(
1692 shuffledElementType,
1693 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1694 bitLo, bitHi));
1695 masks.push_back(mask);
1696
1697 int64_t shiftRight = bitLo;
1698 shiftRightAmounts.push_back(
1699 IntegerAttr::get(shuffledElementType, shiftRight));
1700
1701 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1702 shiftLeftAmounts.push_back(
1703 IntegerAttr::get(shuffledElementType, shiftLeft));
1704 }
1705
1706 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1707 }
1708 return result;
1709}
1710
1711Value BitCastRewriter::genericRewriteStep(
1712 PatternRewriter &rewriter, Location loc, Value initialValue,
1713 Value runningResult, const BitCastRewriter::Metadata &metadata) {
1714 // Create vector.shuffle from the metadata.
1715 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1716 initialValue, metadata.shuffles);
1717
1718 // Intersect with the mask.
1719 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1720 auto constOp = arith::ConstantOp::create(
1721 rewriter, loc,
1722 DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1723 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1724
1725 // Align right on 0.
1726 auto shiftRightConstantOp = arith::ConstantOp::create(
1727 rewriter, loc,
1728 DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1729 Value shiftedRight =
1730 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1731
1732 // Shift bits left into their final position.
1733 auto shiftLeftConstantOp = arith::ConstantOp::create(
1734 rewriter, loc,
1735 DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1736 Value shiftedLeft =
1737 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1738
1739 runningResult =
1740 runningResult
1741 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1742 : shiftedLeft;
1743
1744 return runningResult;
1745}
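At the scalar level, one column of this sequence behaves as sketched below (plain C++; the `step` lambda and the sample 5-bit values are illustrative). It reassembles target element 1 of the `vector<32xi5> -> vector<20xi8>` enumeration shown earlier:

```cpp
// Scalar sketch of the (shuffle -> and -> shr -> shl -> or) chain built by
// genericRewriteStep, applied to one target byte.
#include <cassert>
#include <cstdint>

int main() {
  // Arbitrary 5-bit source elements (only the ones feeding target element 1).
  uint8_t src1 = 0b10110, src2 = 0b01101, src3 = 0b00001;

  auto step = [](uint8_t elem, unsigned bitLo, unsigned bitHi, unsigned lshl,
                 uint8_t running) -> uint8_t {
    uint8_t mask = static_cast<uint8_t>(((1u << (bitHi - bitLo)) - 1) << bitLo);
    uint8_t bits = static_cast<uint8_t>((elem & mask) >> bitLo); // align at 0
    return static_cast<uint8_t>(running | (bits << lshl));       // final spot
  };

  // Row 1 of the enumeration: {1,[3..5) lshl 0}{2,[0..5) lshl 2}{3,[0..1) lshl 7}.
  uint8_t result = 0;
  result = step(src1, 3, 5, 0, result);
  result = step(src2, 0, 5, 2, result);
  result = step(src3, 0, 1, 7, result);

  // Same byte assembled directly from the documented bit layout.
  uint8_t expected = static_cast<uint8_t>(((src1 >> 3) & 0x3) |
                                          ((src2 & 0x1F) << 2) |
                                          ((src3 & 0x1) << 7));
  assert(result == expected && result == 182);
  return 0;
}
```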
1746
1747/// Bitcasts the aligned `subByteVec` vector to a vector of i8, where
1748/// "aligned" means that it satisfies alignedConversionPrecondition.
1749///
1750/// Example:
1751/// vector<16x16xi2> -> vector<16x4xi8>
1752/// vector<16x16xi4> -> vector<16x8xi8>
1753static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1754 Value subByteVec) {
1755 auto srcVecType = cast<VectorType>(subByteVec.getType());
1756 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1757 assert(8 % srcBitwidth == 0 &&
1758 "Unsupported sub-byte type (not a divisor of i8)");
1759 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1760 SmallVector<int64_t> vecShape(srcVecType.getShape());
1761 // Adjust last dimension of the vector, so the total size remains the same.
1762 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1763 auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1764 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1765}
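The bitcast only reinterprets bits; the extraction helpers that follow additionally assume that sub-byte element 0 sits in the least significant bits of its byte (the little-endian layout these rewrites target). A scalar sketch of that layout for four hypothetical i2 values:

```cpp
// Sketch of the byte layout assumed below: element k of a vector<4xi2>
// occupies bits [2k, 2k+2) of the containing byte.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t e0 = 0b10, e1 = 0b01, e2 = 0b11, e3 = 0b00; // hypothetical i2 values
  uint8_t packed = static_cast<uint8_t>(e0 | (e1 << 2) | (e2 << 4) | (e3 << 6));
  assert(packed == 0b00110110);
  // Recover element k by shifting its 2-bit field down to position 0.
  uint8_t expected[] = {e0, e1, e2, e3};
  for (int k = 0; k < 4; ++k)
    assert(((packed >> (2 * k)) & 0x3) == expected[k]);
  return 0;
}
```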
1766
1767/// Extracts a signed N-bit sequence from each element of a vector of bytes,
1768/// starting at the specified bit index.
1769/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1770///
1771/// Example for a single element:
1772/// Extract numBits=2 starting at bitIdx=2
1773/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1774/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1775/// target = [. . . . ^ ^ . .]
1776///
1777/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1778/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1779///
1780/// src = [01 01 11 10]
1781/// shl = arith.shl(src, 4) -> [11 10 00 00]
1782/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1783static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1784 Location loc, Value src,
1785 int bitIdx, int numBits) {
1786 auto srcType = cast<VectorType>(src.getType());
1787 Value shl = src;
1788 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1789 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1790 "Invalid bitIdx range");
1791 if (bitsToShiftLeft != 0) {
1792 Value shiftLeftValues = arith::ConstantOp::create(
1793 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1794 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1795 }
1796
1797 int8_t bitsToShiftRight = 8 - numBits;
1798 Value shiftRightValues = arith::ConstantOp::create(
1799 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1800 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1801 return shr;
1802}
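A scalar rendition of the example in the comment above (plain C++, assuming two's-complement arithmetic; `extractSigned` is an illustrative name):

```cpp
// Scalar version of extractNBitsPerByteAndSignExtendToI8 for the documented
// example: numBits = 2, bitIdx = 2, src = 0b01011110.
#include <cassert>
#include <cstdint>

static int8_t extractSigned(uint8_t src, int bitIdx, int numBits) {
  int8_t shl = static_cast<int8_t>(src << (8 - numBits - bitIdx)); // drop high bits
  return static_cast<int8_t>(shl >> (8 - numBits)); // arithmetic shift sign-extends
}

int main() {
  assert(extractSigned(0b01011110, /*bitIdx=*/2, /*numBits=*/2) == -1);
  assert(static_cast<uint8_t>(extractSigned(0b01011110, 2, 2)) == 0b11111111);
  return 0;
}
```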
1803
1804/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1805/// starting at the specified bit index.
1806/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1807///
1808/// Example for a single element:
1809/// Extract numBits=2 starting at bitIdx=2
1810/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1811/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1812/// target = [. . . . ^ ^ . .]
1813///
1814/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1815/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1816///
1817/// src = [01 01 10 10]
1818/// mask = [00 00 00 11]
1819/// shr = arith.shrui(src, 2) = [00 01 01 10]
1820/// result = arith.andi(shr, mask) = [00 00 00 10]
1821/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1822/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1823/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1824/// left when the index is 0.
1825static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1826 Location loc, Value src,
1827 int bitIdx, int numBits) {
1828 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1829 "Invalid bitIdx range");
1830 auto srcType = cast<VectorType>(src.getType());
1831 int8_t bitsToShiftRight = bitIdx;
1832 Value shr = src;
1833 if (bitsToShiftRight != 0) {
1834 Value shiftRightValues = arith::ConstantOp::create(
1835 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1836 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1837 }
1838 if (bitIdx + numBits == 8) {
1839 return shr;
1840 }
1841 uint8_t lowBitsMask = (1 << numBits) - 1;
1842 Value lowBitsMaskValues = arith::ConstantOp::create(
1843 rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1844 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1845}
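The unsigned variant, again as a scalar sketch of the example above (`extractUnsigned` is an illustrative name):

```cpp
// Scalar version of extractNBitsPerByteAndExtendToI8 for the documented
// example: numBits = 2, bitIdx = 2, src = 0b01011010.
#include <cassert>
#include <cstdint>

static uint8_t extractUnsigned(uint8_t src, int bitIdx, int numBits) {
  uint8_t shr = static_cast<uint8_t>(src >> bitIdx); // align the field at bit 0
  if (bitIdx + numBits == 8)
    return shr;                                      // no masking needed
  return static_cast<uint8_t>(shr & ((1u << numBits) - 1));
}

int main() {
  assert(extractUnsigned(0b01011010, /*bitIdx=*/2, /*numBits=*/2) == 0b10);
  return 0;
}
```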
1846
1847using ExtractNBitsFn =
1848 std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1849
1850/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1851/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1852static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1853 Value srcValue, const ExtractNBitsFn &extFn) {
1854 [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1855 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1856 "Expected i4 type");
1857
1858 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1859 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1860
1861 // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1862 // byte are placed in one vector and the high i4 elements in another vector.
1863 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1864 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1865
1866 // 3. Interleave low and high i8 elements.
1867 return vector::InterleaveOp::create(rewriter, loc, low, high);
1868}
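The effect of the low/high extraction plus interleave can be checked at the scalar level. This sketch (plain C++, hypothetical values, unsigned extraction for brevity) shows that interleaving the two halves restores the original i4 element order:

```cpp
// Sketch of the i4 -> i8 widening above: bytes hold two i4 elements each
// (element 2k in the low nibble of byte k), the "low" and "high" extractions
// correspond to extFn(..., 0, 4) and extFn(..., 4, 4), and interleaving them
// restores element order.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical vector<4xi4> values packed into two bytes.
  std::vector<uint8_t> i4 = {0x1, 0xF, 0x7, 0x8};
  std::vector<uint8_t> bytes = {static_cast<uint8_t>(i4[0] | (i4[1] << 4)),
                                static_cast<uint8_t>(i4[2] | (i4[3] << 4))};

  std::vector<uint8_t> low, high;
  for (uint8_t b : bytes) {
    low.push_back(b & 0x0F);         // extFn(..., /*bitIdx=*/0, /*numBits=*/4)
    high.push_back((b >> 4) & 0x0F); // extFn(..., /*bitIdx=*/4, /*numBits=*/4)
  }

  // vector.interleave(low, high): low[0], high[0], low[1], high[1], ...
  std::vector<uint8_t> widened;
  for (size_t k = 0; k < low.size(); ++k) {
    widened.push_back(low[k]);
    widened.push_back(high[k]);
  }
  assert(widened == i4); // original element order, now one element per byte
  return 0;
}
```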
1869
1870/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1871/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1872static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1873 Value srcValue, const ExtractNBitsFn &extFn) {
1874 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1875 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1876 "Expected i2 type");
1877
1878 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1879 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1880
1881 // 2. Extract each i2 element
1882 // Position 0 (bits 0-1)
1883 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1884 // Position 1 (bits 2-3)
1885 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1886 // Position 2 (bits 4-5)
1887 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1888 // Position 3 (bits 6-7)
1889 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1890
1891 // 3. Interleave all 4 elements by first interleaving
1892 // even elements and then odd
1893 // vec0 = [0,0,0,0],...
1894 // vec1 = [1,1,1,1],...
1895 // vec2 = [2,2,2,2],...
1896 // vec3 = [3,3,3,3],...
1897 // 02 = [0,2,0,2,0,2,0,2],...
1898 // 13 = [1,3,1,3,1,3,1,3],...
1899 // 0213 = [0,1,2,3,...],...
1900 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1901 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1902 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1903 interleave13);
1904}
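The two-level interleave order is easiest to see on element indices alone. A small sketch (plain C++, illustrative `interleave` helper) for two source bytes, i.e. eight i2 elements:

```cpp
// Index-level sketch of the interleave order used above: positions 0..3 of
// each byte are split into four vectors, and two rounds of interleaving put
// them back into 0, 1, 2, 3, ... order.
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<int> interleave(const std::vector<int> &a,
                                   const std::vector<int> &b) {
  std::vector<int> out;
  for (size_t i = 0; i < a.size(); ++i) {
    out.push_back(a[i]);
    out.push_back(b[i]);
  }
  return out;
}

int main() {
  // Element indices contributed by each extraction, for 2 source bytes.
  std::vector<int> vec0 = {0, 4}, vec1 = {1, 5}, vec2 = {2, 6}, vec3 = {3, 7};
  std::vector<int> i02 = interleave(vec0, vec2); // {0, 2, 4, 6}
  std::vector<int> i13 = interleave(vec1, vec3); // {1, 3, 5, 7}
  std::vector<int> all = interleave(i02, i13);   // {0, 1, 2, ..., 7}
  assert((all == std::vector<int>{0, 1, 2, 3, 4, 5, 6, 7}));
  return 0;
}
```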
1905
1906/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1907/// ops to avoid leaving LLVM to scramble with peephole optimizations.
1908static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1909 Value srcValue) {
1910 VectorType srcVecType = cast<VectorType>(srcValue.getType());
1911 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1912 "Expected i8 type");
1913
1914 // 1. De-interleave low and high i8 elements.
1915 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1916
1917 // 2. Zero out the upper side of each low i8 element.
1918 constexpr int8_t i8LowBitMask = 0x0F;
1919 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1920 Value zeroOutMask = arith::ConstantOp::create(
1921 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1922 Value zeroOutLow = arith::AndIOp::create(
1923 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1924
1925 // 3. Move high i4 values to upper side of the byte.
1926 constexpr int8_t bitsToShift = 4;
1927 auto shiftValues = arith::ConstantOp::create(
1928 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1929 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1930 shiftValues);
1931
1932 // 4. Merge high and low i4 values.
1933 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1934
1935 // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1936 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1937 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1938}
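At the scalar level the packing is the inverse of the i4 -> i8 widening above: even-indexed elements land in the low nibble and odd-indexed elements in the high nibble of each byte. A minimal sketch with hypothetical values:

```cpp
// Scalar sketch of the i8 -> i4 packing above: deinterleave, mask the even
// half to its low nibble, shift the odd half into the high nibble, then or.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical i8 values that are already in i4 range.
  std::vector<uint8_t> src = {0x1, 0xF, 0x7, 0x8};

  std::vector<uint8_t> even, odd; // vector.deinterleave
  for (size_t k = 0; k < src.size(); k += 2) {
    even.push_back(src[k]);
    odd.push_back(src[k + 1]);
  }

  std::vector<uint8_t> packed;
  for (size_t k = 0; k < even.size(); ++k)
    packed.push_back(static_cast<uint8_t>((even[k] & 0x0F) | (odd[k] << 4)));

  assert((packed == std::vector<uint8_t>{0xF1, 0x87}));
  return 0;
}
```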
1939
1940namespace {
1941/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1942/// advantage of high-level information to avoid leaving LLVM to scramble with
1943/// peephole optimizations.
1944struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1945 using Base::Base;
1946
1947 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1948 PatternRewriter &rewriter) const override {
1949 // The source must be a trunc op.
1950 auto truncOp =
1951 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1952 if (!truncOp)
1953 return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1954
1955 // Set up the BitCastRewriter and verify the precondition.
1956 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1957 VectorType targetVectorType = bitCastOp.getResultVectorType();
1958 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1959 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1960 return failure();
1961
1962 // Perform the rewrite.
1963 Value truncValue = truncOp.getIn();
1964 auto shuffledElementType =
1965 cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1966 Value runningResult;
1967 for (const BitCastRewriter ::Metadata &metadata :
1968 bcr.precomputeMetadata(shuffledElementType)) {
1969 runningResult = bcr.genericRewriteStep(
1970 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1971 }
1972
1973 // Finalize the rewrite.
1974 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1975 shuffledElementType.getIntOrFloatBitWidth();
1976 if (narrowing) {
1977 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1978 rewriter.replaceOp(bitCastOp, runningResult);
1979 } else {
1980 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1981 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1982 }
1983 } else {
1984 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1985 rewriter.replaceOp(bitCastOp, runningResult);
1986 } else {
1987 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1988 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1989 }
1990 }
1991
1992 return success();
1993 }
1994};
1995} // namespace
1996
1997//===----------------------------------------------------------------------===//
1998// RewriteExtOfBitCast
1999//===----------------------------------------------------------------------===//
2000
2001namespace {
2002/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
2003/// take advantage of high-level information to avoid leaving LLVM to scramble
2004/// with peephole optimizations.
2005template <typename ExtOpType>
2006struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2007 using OpRewritePattern<ExtOpType>::OpRewritePattern;
2008
2009 RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2010 : OpRewritePattern<ExtOpType>(context, benefit) {}
2011
2012 LogicalResult matchAndRewrite(ExtOpType extOp,
2013 PatternRewriter &rewriter) const override {
2014 // The source must be a bitcast op.
2015 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2016 if (!bitCastOp)
2017 return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2018
2019 // Set up the BitCastRewriter and verify the precondition.
2020 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2021 VectorType targetVectorType = bitCastOp.getResultVectorType();
2022 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2023 if (failed(bcr.commonPrecondition(
2024 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2025 return failure();
2026
2027 // Perform the rewrite.
2028 Value runningResult;
2029 Value sourceValue = bitCastOp.getSource();
2030 auto shuffledElementType =
2031 cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2032 for (const BitCastRewriter::Metadata &metadata :
2033 bcr.precomputeMetadata(shuffledElementType)) {
2034 runningResult = bcr.genericRewriteStep(
2035 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2036 }
2037
2038 // Finalize the rewrite.
2039 bool narrowing =
2040 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2041 shuffledElementType.getIntOrFloatBitWidth();
2042 if (narrowing) {
2043 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2044 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2045 } else {
2046 rewriter.replaceOpWithNewOp<ExtOpType>(
2047 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2048 }
2049
2050 return success();
2051 }
2052};
2053
2054/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2055/// bitwise ops that take advantage of high-level information to avoid leaving
2056/// LLVM to scramble with peephole optimizations. Templated to choose between
2057/// signed and unsigned conversions.
2058///
2059/// EXAMPLE 1 (signed):
2060/// arith.extsi %in : vector<8xi4> to vector<8xi32>
2061/// is rewritten as:
2062/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2063/// %1 = arith.shli %0, 4 : vector<4xi8>
2064/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2065/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2066/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2067/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2068///
2069/// EXAMPLE 2 (fp):
2070/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2071/// is rewritten as:
2072/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2073/// %1 = arith.shli %0, 4 : vector<4xi8>
2074/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2075/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2076/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2077/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2078///
2079/// EXAMPLE 3 (unsigned):
2080/// arith.extui %in : vector<8xi4> to vector<8xi32>
2081/// is rewritten as:
2082/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2083/// %1 = arith.andi %0, 15 : vector<4xi8>
2084/// %2 = arith.shrui %0, 4 : vector<4xi8>
2085/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2086/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2087///
2088template <typename ConversionOpType, bool isSigned>
2089struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2090 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2091
2092 LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2093 PatternRewriter &rewriter) const override {
2094 // Verify the preconditions.
2095 Value srcValue = conversionOp.getIn();
2096 VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2097 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2098
2099 if (failed(
2100 commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2101 return failure();
2102
2103 // Check general alignment preconditions.
2104 if (failed(alignedConversionPrecondition(
2105 rewriter, srcVecType,
2106 /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2107 return failure();
2108
2109 // Perform the rewrite.
2110 Location loc = conversionOp.getLoc();
2111 const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2112 : extractNBitsPerByteAndExtendToI8;
2113 Value subByteExt;
2114 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2115 case 2:
2116 subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2117 break;
2118 case 4:
2119 subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2120 break;
2121 default:
2122 return failure();
2123 }
2124
2125 // Finalize the rewrite.
2126 rewriter.replaceOpWithNewOp<ConversionOpType>(
2127 conversionOp, conversionOp.getType(), subByteExt);
2128 return success();
2129 }
2130};
2131
2132/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2133/// bitwise ops that take advantage of high-level information to avoid leaving
2134/// LLVM to scramble with peephole optimizations.
2135///
2136/// For example:
2137/// arith.trunci %in : vector<8xi32> to vector<8xi4>
2138///
2139/// is rewritten as:
2140///
2141/// %cst = arith.constant dense<15> : vector<4xi8>
2142/// %cst_0 = arith.constant dense<4> : vector<4xi8>
2143/// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
2144/// %2 = arith.andi %0, %cst : vector<4xi8>
2145/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2146/// %4 = arith.ori %2, %3 : vector<4xi8>
2147/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2148///
2149struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2150 using Base::Base;
2151
2152 LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2153 PatternRewriter &rewriter) const override {
2154 // Verify the preconditions.
2155 Value srcValue = truncOp.getIn();
2156 auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2157 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2158 if (!srcVecType || !dstVecType)
2159 return failure();
2160
2161 if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2162 return failure();
2163
2164 // TODO: Add support for truncating to i2.
2165 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2166 return failure();
2167
2168 // Check general alignment preconditions. We invert the src/dst type order
2169 // to reuse the existing precondition logic.
2170 if (failed(alignedConversionPrecondition(
2171 rewriter, dstVecType,
2172 /*containerTy=*/rewriter.getI8Type(), truncOp)))
2173 return failure();
2174
2175 // Create a new iX -> i8 truncation op.
2176 Location loc = truncOp.getLoc();
2177 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2178 Value i8TruncVal =
2179 arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2180
2181 // Rewrite the i8 -> i4 truncation part.
2182 Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2183
2184 // Finalize the rewrite.
2185 rewriter.replaceOp(truncOp, subByteTrunc);
2186 return success();
2187 }
2188};
2189
2190/// Rewrite a sub-byte vector transpose into a sequence of instructions that
2191/// perform the transpose on wider (byte) element types.
2192///
2193/// EXAMPLE:
2194/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2195///
2196/// is rewritten as:
2197///
2198/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2199/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2200/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2201///
2202struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2203 using Base::Base;
2204
2205 RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2206 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2207
2208 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2209 PatternRewriter &rewriter) const override {
2210 // Precondition: sub-byte integer transpose.
2211 constexpr unsigned minNativeBitwidth = 8;
2212 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2213 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2214 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2215 return rewriter.notifyMatchFailure(transposeOp,
2216 "not a sub-byte transpose");
2217 }
2218
2219 // Perform the rewrite.
2220 Location loc = transposeOp.getLoc();
2221 // Signed/unsigned interpretation shouldn't matter here as we are just
2222 // transposing the elements and truncating them back to the original size.
2223 // TODO: Use unsigned extension (more efficient) when emulation or backend
2224 // support is available.
2225 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2226 std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2227 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2228 transposeOp.getVector());
2229 Value newTranspose = vector::TransposeOp::create(
2230 rewriter, loc, extOp, transposeOp.getPermutation());
2231 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2232 rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2233 newTranspose);
2234 return success();
2235 }
2236};
2237
2238} // namespace
2239
2240//===----------------------------------------------------------------------===//
2241// Public Interface Definition
2242//===----------------------------------------------------------------------===//
2243
2244// The emulated type is inferred from the converted memref type.
2245void vector::populateVectorNarrowTypeEmulationPatterns(
2246 const arith::NarrowTypeEmulationConverter &typeConverter,
2247 RewritePatternSet &patterns, bool disableAtomicRMW) {
2248 // Populate `vector.*` conversion patterns.
2249 // TODO: #119553 support atomicity
2250 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2251 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2252 typeConverter, patterns.getContext());
2253
2254 // Populate `vector.*` store conversion patterns. The caller can choose
2255 // to avoid emitting atomic operations and reduce stores to plain
2256 // read-modify-write sequences if it is known there is no thread contention.
2257 patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2258}
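For reference, a typical way to drive these conversion patterns, modeled loosely on the in-tree narrow-type emulation test pass. This is a sketch only: the function name, the i8 container bitwidth, the legality rule, and the companion memref calls are assumptions, not prescribed by this file.

```cpp
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: emulate narrow (e.g. i4) vector loads/stores on `root` using an i8
// container type. The legality rule and the memref companion calls are
// illustrative; adapt them to the surrounding pass.
static LogicalResult emulateNarrowTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*disableAtomicRMW=*/false);
  return applyPartialConversion(root, target, std::move(patterns));
}
```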
2259
2260void vector::populateVectorNarrowTypeRewritePatterns(
2261 RewritePatternSet &patterns, PatternBenefit benefit) {
2262 // TODO: Document what the emulated type is.
2263 patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2264 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2265 benefit);
2266
2267 // Patterns for aligned cases. We set higher priority as they are expected to
2268 // generate better performance for aligned cases.
2269 // The container type is always i8.
2270 patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2271 RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2272 RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2273 benefit.getBenefit() + 1);
2274 // The container type is always i8.
2275 patterns
2276 .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2277 RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2278 patterns.getContext(), benefit.getBenefit() + 1);
2279}
2280
2281// The container type is always i8.
2282void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2283 RewritePatternSet &patterns, PatternBenefit benefit) {
2284 patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2285}
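These rewrite patterns are ordinary rewrite patterns and can be driven greedily. A minimal sketch (the wrapper function name and the explicit benefit values are illustrative):

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch: apply the bitcast/ext/trunc and transpose rewrites above with the
// greedy pattern driver; in-tree users typically populate these alongside
// other vector-lowering patterns.
static LogicalResult rewriteNarrowTypeCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1);
  vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns,
                                                           /*benefit=*/1);
  return applyPatternsGreedily(root, std::move(patterns));
}
```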
2286
2287void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2288 arith::NarrowTypeEmulationConverter &typeConverter,
2289 RewritePatternSet &patterns) {
2290 memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
2291 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
2292}