1//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to emulate
10// narrow types that are not supported by the target hardware, e.g. i4
11// ("emulated type"), using wider types, e.g. i8 ("container type").
12//
13/// Currently, only power-of-two integer types are supported. These are
14/// converted to wider integers that are at least 8 bits wide.
15///
16/// TODO: Support for non-powers-of-two.
17//===----------------------------------------------------------------------===//
18
32#include "mlir/IR/Value.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Support/raw_ostream.h"
38#include <cstdint>
39#include <optional>
40
42
43using namespace mlir;
44
45#define DEBUG_TYPE "vector-narrow-type-emulation"
46
47using VectorValue = TypedValue<VectorType>;
48using MemRefValue = TypedValue<MemRefType>;
49
50//===----------------------------------------------------------------------===//
51// Utils
52//===----------------------------------------------------------------------===//
53
54/// Returns a compressed mask for the emulated vector. For example, when
55/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
56/// elements span two dest elements), this method compresses `vector<8xi1>`
57/// into `vector<2xi1>`.
58///
59/// A bit in the compressed/output mask is set iff any bit in the corresponding
60/// `numSrcElemsPerDest`-wide range of the uncompressed/input mask is set. E.g.,
61/// if `numSrcElemsPerDest` is 2 and `numFrontPadElems` is 1, the following
62/// mask:
63///
64/// %mask = [1, 1, 0, 0, 0, 0]
65///
66/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
67/// will be added in the back to make the number of elements a multiple of
68/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
69///
70/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
71///
72/// then it will return the following new compressed mask:
73///
74/// %mask = [1, 1, 0, 0]
75///
76/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
77/// `numSrcElemsPerDest`.
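///
/// As an MLIR-level illustration (a sketch, not generated verbatim): with
/// `numSrcElemsPerDest` = 4 and `numFrontPadElems` = 0, a mask defined by
///
///   %mask = vector.create_mask %c6 : vector<8xi1>
///
/// is compressed to cover ceilDiv(6, 4) = 2 container elements:
///
///   %new_mask = vector.create_mask %c2 : vector<2xi1>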
78static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
79 Location loc, Value mask,
80 int numSrcElems,
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
83
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
86
87 auto numDestElems =
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
89 numSrcElemsPerDest;
90
91 Operation *maskOp = mask.getDefiningOp();
92  SmallVector<vector::ExtractOp, 2> extractOps;
93  // TODO: add support to `vector.broadcast`.
94 // Finding the mask creation operation.
95 while (maskOp &&
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
97 maskOp)) {
98 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
101 }
102 }
103
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
105 maskOp))
106 return failure();
107
108 // Computing the "compressed" mask. All the emulation logic (i.e. computing
109 // new mask index) only happens on the last dimension of the vectors.
110 SmallVector<int64_t> maskShape(
111 cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
112 maskShape.back() = numDestElems;
113 auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
114 std::optional<Operation *> newMask =
115      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
116          .Case(
117 [&](vector::CreateMaskOp createMaskOp)
118 -> std::optional<Operation *> {
119 OperandRange maskOperands = createMaskOp.getOperands();
120 // The `vector.create_mask` op creates a mask arrangement
121 // without any zeros at the front. Also, because
122 // `numFrontPadElems` is strictly smaller than
123 // `numSrcElemsPerDest`, the compressed mask generated by
124 // padding the original mask by `numFrontPadElems` will not
125                // have any zeros at the front either.
126 AffineExpr s0;
127 bindSymbols(rewriter.getContext(), s0);
128 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
129 OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
130            OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
131                rewriter, loc, s0, origIndex);
132 SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133 newMaskOperands.push_back(
134 getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
136 newMaskOperands);
137 })
138 .Case([&](vector::ConstantMaskOp constantMaskOp)
139 -> std::optional<Operation *> {
140 // Take the shape of mask, compress its trailing dimension:
141 SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
142 int64_t &maskIndex = maskDimSizes.back();
143 maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
144 numSrcElemsPerDest);
145 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
146 maskDimSizes);
147 })
148 .Case([&](arith::ConstantOp constantOp)
149 -> std::optional<Operation *> {
150 // TODO: Support multiple dimensions.
151 if (maskShape.size() != 1)
152 return std::nullopt;
153 // Rearrange the original mask values to cover the whole potential
154 // loading region. For example, in the case of using byte-size for
155 // emulation, given the following mask:
156 //
157 // %mask = [0, 1, 0, 1, 0, 0]
158 //
159 // With front offset of 1, the mask will be padded 0s in the front
160 // and back so that:
161 // 1. It is aligned with the effective loading bits
162            //    2. Its length is a multiple of `numSrcElemsPerDest` (and the total
163            //       coverage size is a multiple of bytes). The new mask will be like
164 // this before compressing:
165 //
166 // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
167 auto originalMask =
168 cast<DenseIntElementsAttr>(constantOp.getValue());
169 SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
170 paddedMaskValues.append(originalMask.template value_begin<bool>(),
171 originalMask.template value_end<bool>());
172 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
173
174 // Compressing by combining every `numSrcElemsPerDest` elements:
175 SmallVector<bool> compressedMaskValues;
176 for (size_t i = 0; i < paddedMaskValues.size();
177 i += numSrcElemsPerDest) {
178 bool combinedValue = false;
179 for (int j = 0; j < numSrcElemsPerDest; ++j) {
180 combinedValue |= paddedMaskValues[i + j];
181 }
182 compressedMaskValues.push_back(combinedValue);
183 }
184 return arith::ConstantOp::create(
185 rewriter, loc,
186 DenseElementsAttr::get(newMaskType, compressedMaskValues));
187 });
188
189 if (!newMask)
190 return failure();
191
192 while (!extractOps.empty()) {
193 newMask =
194 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
195 extractOps.back().getMixedPosition());
196 extractOps.pop_back();
197 }
198
199 return *newMask;
200}
201
202/// Extracts 1-D subvector from a 1-D vector.
203///
204/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
205/// from `src`, starting at `offset`. The result is also a rank-1 vector:
206///
207///    vector<numElemsToExtract x !elType>
208///
209/// (`!elType` is the element type of the source vector). As `offset` is a known
210/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
211///
212/// EXAMPLE:
213/// %res = vector.extract_strided_slice %src
214/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
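///
/// For instance, extracting 3 elements starting at offset 2 from a
/// `vector<8xi2>` (an illustrative sketch):
///   %res = vector.extract_strided_slice %src
///     {offsets = [2], sizes = [3], strides = [1]}
///     : vector<8xi2> to vector<3xi2>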
215static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
216                                        Value src, int64_t offset,
217 int64_t numElemsToExtract) {
218 auto vectorType = cast<VectorType>(src.getType());
219  assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
220 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
221 "subvector out of bounds");
222
223 // When extracting all available elements, just use the source vector as the
224 // result.
225 if (vectorType.getNumElements() == numElemsToExtract)
226 return src;
227
228 auto offsets = rewriter.getI64ArrayAttr({offset});
229 auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
230 auto strides = rewriter.getI64ArrayAttr({1});
231
232 auto resultVectorType =
233 VectorType::get({numElemsToExtract}, vectorType.getElementType());
234 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
235 src, offsets, sizes, strides)
236 ->getResult(0);
237}
238
239/// Inserts 1-D subvector into a 1-D vector.
240///
241/// Inserts the input rank-1 source vector into the destination vector starting
242/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
243/// `vector.insert_strided_slice`.
244///
245/// EXAMPLE:
246/// %res = vector.insert_strided_slice %src, %dest
247///    {offsets = [offset], strides = [1]}
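///
/// For instance, inserting a `vector<3xi2>` into a `vector<8xi2>` at offset 2
/// (an illustrative sketch):
///   %res = vector.insert_strided_slice %src, %dest
///     {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>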
248static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
249                                       Value src, Value dest, int64_t offset) {
250 [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
251 [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
252 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253 "expected source and dest to be rank-1 vector types");
254
255  // If overwriting the destination vector, just return the source.
256 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257 return src;
258
259 auto offsets = rewriter.getI64ArrayAttr({offset});
260 auto strides = rewriter.getI64ArrayAttr({1});
261 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
262 dest, offsets, strides);
263}
264
265/// Extracts 1-D subvector from a 1-D vector.
266///
267/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268/// from `src`, starting at `offset`. The result is also a rank-1 vector:
269///
270///    vector<numElemsToExtract x !elType>
271///
272/// (`!elType` is the element type of the source vector). As `offset` is assumed
273/// to be a _dynamic_ SSA value, this helper method generates a sequence of
274/// `vector.extract` + `vector.insert` pairs.
275///
276/// EXAMPLE:
277/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279/// %c1 = arith.constant 1 : index
280/// %idx2 = arith.addi %offset, %c1 : index
281/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283/// (...)
284static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
285                                         Value src, Value dest,
286 OpFoldResult offset,
287 int64_t numElemsToExtract) {
288 auto srcVecTy = cast<VectorType>(src.getType());
289  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
290  // NOTE: We are unable to take the offset into account in the following
291  // assert, hence it is still possible that the subvector is out-of-bounds
292  // even if the condition is true.
293 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294 "subvector out of bounds");
295
296 // When extracting all available elements, just use the source vector as the
297 // result.
298 if (srcVecTy.getNumElements() == numElemsToExtract)
299 return src;
300
301 for (int i = 0; i < numElemsToExtract; ++i) {
302 Value extractLoc =
303 (i == 0) ? dyn_cast<Value>(offset)
304 : arith::AddIOp::create(
305 rewriter, loc, rewriter.getIndexType(),
306 dyn_cast<Value>(offset),
307 arith::ConstantIndexOp::create(rewriter, loc, i));
308 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
309 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
310 }
311 return dest;
312}
313
314/// Inserts 1-D subvector into a 1-D vector.
315///
316/// Inserts the input rank-1 source vector into the destination vector starting
317/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
318/// uses a sequence of `vector.extract` + `vector.insert` pairs.
319///
320/// EXAMPLE:
321/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
322/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
323/// %c1 = arith.constant 1 : index
324/// %idx2 = arith.addi %offset, %c1 : index
325/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
326/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
327/// (...)
328static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
329                                        Value src, Value dest,
330 OpFoldResult offset,
331 int64_t numElemsToInsert) {
332 auto srcVecTy = cast<VectorType>(src.getType());
333 auto destVecTy = cast<VectorType>(dest.getType());
334 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
335 "expected source and dest to be rank-1 vector types");
336 (void)srcVecTy;
337 (void)destVecTy;
338 assert(numElemsToInsert > 0 &&
339 "the number of elements to insert must be greater than 0");
340 // NOTE: We are unable to take the offset into account in the following
341  // assert, hence it is still possible that the subvector is out-of-bounds
342  // even if the condition is true.
343 assert(numElemsToInsert <= destVecTy.getNumElements() &&
344 "subvector out of bounds");
345
346 Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
347 for (int64_t i = 0; i < numElemsToInsert; ++i) {
348 auto insertLoc =
349 i == 0 ? destOffsetVal
350 : arith::AddIOp::create(
351 rewriter, loc, rewriter.getIndexType(), destOffsetVal,
352 arith::ConstantIndexOp::create(rewriter, loc, i));
353 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
354 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
355 }
356 return dest;
357}
358
359/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
360///
361/// Specifically, use `containerElemTy` for loading a vector of
362/// `emulatedElemTy`. The load location is given by `base` and
363/// `linearizedIndices`, and the load size is given by
364/// `numContainerElemsToLoad` (expressed in container elements).
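///
/// For example, loading 4 x i4 via 2 x i8 containers lowers to (a sketch,
/// assuming a linearized i8 memref):
///
///   %load = vector.load %base[%linearizedIndices] : memref<?xi8>, vector<2xi8>
///   %res  = vector.bitcast %load : vector<2xi8> to vector<4xi4>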
365static Value emulatedVectorLoad(OpBuilder &rewriter, Location loc,
366                                Value base,
367 OpFoldResult linearizedIndices,
368 int64_t numContainerElemsToLoad,
369 Type emulatedElemTy,
370 Type containerElemTy) {
371 auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
372 emulatedElemTy.getIntOrFloatBitWidth();
373 auto newLoad = vector::LoadOp::create(
374 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
375 base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
376 return vector::BitCastOp::create(
377 rewriter, loc,
378 VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
379 emulatedElemTy),
380 newLoad);
381}
382
383/// Downcasts two values to `downcastType`, selects between them based on
384/// `mask`, and casts the result to `upcastType`.
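///
/// For example, with `downcastType = vector<4xi2>` and
/// `upcastType = vector<1xi8>`, this emits roughly (a sketch; the bitcasts are
/// only created for inputs that are not already of `downcastType`):
///   %t = vector.bitcast %trueValue : vector<1xi8> to vector<4xi2>
///   %f = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>
///   %s = arith.select %mask, %t, %f : vector<4xi1>, vector<4xi2>
///   %r = vector.bitcast %s : vector<4xi2> to vector<1xi8>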
385static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
386                                     VectorType downcastType,
387 VectorType upcastType, Value mask,
388 Value trueValue, Value falseValue) {
389 assert(
390 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
391 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
392 "expected input and output number of bits to match");
393 if (trueValue.getType() != downcastType) {
394 trueValue =
395 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
396 }
397 if (falseValue.getType() != downcastType) {
398 falseValue =
399 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
400 }
401 Value selectedType =
402 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
403 // Upcast the selected value to the new type.
404 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
405}
406
407/// Emits a `memref.generic_atomic_rmw` op to store a sub-byte-sized value to
408/// a byte in `linearizedMemref`, with a mask. `valueToStore` is a vector of
409/// sub-byte-sized elements that together span a single byte (8 bits), and the
410/// mask selects which of those elements to store.
411///
412/// Inputs:
413/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
414/// storeIdx = 2
415/// valueToStore = |3|3|3|3| : vector<4xi2>
416/// mask = |0|0|1|1| : vector<4xi1>
417///
418/// Result:
419/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
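///
/// The emitted IR looks roughly as follows (a sketch; SSA names are
/// illustrative and the select logic is the `downcastSelectAndUpcast` helper
/// above):
///
///   %res = memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<1xi8> {
///   ^bb0(%current : i8):
///     %cur_vec = vector.from_elements %current : vector<1xi8>
///     // ... bitcast + arith.select of `valueToStore` vs `%cur_vec` using `mask` ...
///     %new_byte = vector.extract %masked[0] : i8 from vector<1xi8>
///     memref.atomic_yield %new_byte : i8
///   }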
420static void atomicRMW(OpBuilder &builder, Location loc,
421 MemRefValue linearizedMemref, Value storeIdx,
422 VectorValue valueToStore, Value mask) {
423 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
424
425 // Create an atomic load-modify-write region using
426 // `memref.generic_atomic_rmw`.
427 auto atomicOp = memref::GenericAtomicRMWOp::create(
428 builder, loc, linearizedMemref, ValueRange{storeIdx});
429 Value origValue = atomicOp.getCurrentValue();
430
431 OpBuilder::InsertionGuard guard(builder);
432 builder.setInsertionPointToStart(atomicOp.getBody());
433
434 // Load the original value from memory, and cast it to the original element
435 // type.
436 auto oneElemVecType = VectorType::get({1}, origValue.getType());
437 Value origVecValue = vector::FromElementsOp::create(
438 builder, loc, oneElemVecType, ValueRange{origValue});
439
440 // Construct the final masked value and yield it.
441 Value maskedValue =
442 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
443 oneElemVecType, mask, valueToStore, origVecValue);
444 auto scalarMaskedValue =
445 vector::ExtractOp::create(builder, loc, maskedValue, 0);
446 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
447}
448
449/// Generate a non-atomic read-modify-write sequence for storing to the emulated
450/// type. It has the same logic as `atomicRMW`, but without atomicity.
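///
/// The emitted IR looks roughly like this (a sketch; SSA names are
/// illustrative):
///   %old = vector.load %linearizedMemref[%linearizedIndex]
///            : memref<3xi8>, vector<1xi8>
///   %old_cast = vector.bitcast %old : vector<1xi8> to vector<4xi2>
///   %sel = arith.select %mask, %valueToStore, %old_cast
///            : vector<4xi1>, vector<4xi2>
///   %new = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///   vector.store %new, %linearizedMemref[%linearizedIndex]
///            : memref<3xi8>, vector<1xi8>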
451static void nonAtomicRMW(OpBuilder &builder, Location loc,
452 MemRefValue linearizedMemref, Value linearizedIndex,
453 VectorValue valueToStore, Value mask) {
454 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
455
456 auto oneElemVecType =
457 VectorType::get({1}, linearizedMemref.getType().getElementType());
458 Value origVecValue =
459 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
460 ValueRange{linearizedIndex});
461 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
462 origVecValue);
463
464 Value maskedValue =
465 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
466 oneElemVecType, mask, valueToStore, origVecValue);
467 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
468 linearizedIndex);
469}
470
471/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
472/// and insert it into an empty vector at `insertOffset`.
473/// Inputs:
474/// vec_in = |0|1|2|3| : vector<4xi2>
475/// extractOffset = 1
476/// sliceNumElements = 2
477/// insertOffset = 2
478/// Output:
479/// vec_out = |0|0|1|2| : vector<4xi2>
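///
/// In terms of the generated ops, this is roughly (a sketch):
///   %zeros = arith.constant dense<0> : vector<4xi2>
///   %slice = vector.extract_strided_slice %vec_in
///     {offsets = [1], sizes = [2], strides = [1]} : vector<4xi2> to vector<2xi2>
///   %vec_out = vector.insert_strided_slice %slice, %zeros
///     {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>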
480static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
481                                  Location loc, VectorValue vector,
482                                  int64_t extractOffset,
483 int64_t sliceNumElements,
484 int64_t insertOffset) {
485 assert(vector.getType().getRank() == 1 && "expected 1-D vector");
486 auto vectorElementType = vector.getType().getElementType();
487  // TODO: update and use `alignedConversionPrecondition` in place of these
488  // asserts.
489 assert(
490 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
491 "sliceNumElements * vector element size must be less than or equal to 8");
492 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
493 "vector element must be a valid sub-byte type");
494 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
495 auto emptyByteVector = arith::ConstantOp::create(
496 rewriter, loc,
497 VectorType::get({emulatedPerContainerElem}, vectorElementType),
498 rewriter.getZeroAttr(
499 VectorType::get({emulatedPerContainerElem}, vectorElementType)));
500 auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
501 extractOffset, sliceNumElements);
502 return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
503 insertOffset);
504}
505
506namespace {
507
508//===----------------------------------------------------------------------===//
509// ConvertVectorStore
510//===----------------------------------------------------------------------===//
511
512// Emulate `vector.store` using a multi-byte container type.
513//
514// When `assumeAligned` is true, store offsets are assumed to be aligned to
515// container element boundaries, so a store whose source vector fills whole
516// container elements (isDivisibleInSize) is emitted as a simple bitcast +
517// store without checking the offset. Stores that are not divisible in size
518// are rejected. This is useful for downstream users that have already
519// ensured alignment.
520//
521// The container type is obtained through the op adaptor and would normally be
522// generated via `NarrowTypeEmulationConverter`.
523//
524// EXAMPLE 1
525// (aligned store of i4, emulated using i8 as the container type)
526//
527// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
528//
529// is rewritten as:
530//
531// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
532// vector.store %src_bitcast, %dest_bitcast[%idx]
533// : memref<16xi8>, vector<4xi8>
534//
535// EXAMPLE 2
536// (unaligned store of i2, emulated using i8 as the container type)
537//
538//    vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
539//
540// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
541// is modelled using 3 bytes:
542//
543// Byte 0 Byte 1 Byte 2
544// +----------+----------+----------+
545// | oooooooo | ooooNNNN | NNoooooo |
546// +----------+----------+----------+
547//
548// N - (N)ew entries (i.e. to be overwritten by vector.store)
549// o - (o)ld entries (to be preserved)
550//
551// For the generated output in the non-atomic case, see:
552//  * `@vector_store_i2_const_index_two_partial_stores`
553// in:
554// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
555//
556// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
557// `true` to generate non-atomic RMW sequences.
558struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
559 using Base::Base;
560
561 ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW,
562 bool assumeAligned)
563 : OpConversionPattern<vector::StoreOp>(context),
564 disableAtomicRMW(disableAtomicRMW), assumeAligned(assumeAligned) {}
565
566 LogicalResult
567 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569
570 if (op.getValueToStore().getType().getRank() != 1)
571 return rewriter.notifyMatchFailure(op,
572 "only 1-D vectors are supported ATM");
573
574 auto loc = op.getLoc();
575
576 auto valueToStore = cast<VectorValue>(op.getValueToStore());
577 auto containerElemTy =
578 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
579 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
580 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
581 int containerBits = containerElemTy.getIntOrFloatBitWidth();
582
583 // Check per-element alignment.
584 if (containerBits % emulatedBits != 0) {
585 return rewriter.notifyMatchFailure(
586 op, "impossible to pack emulated elements into container elements "
587 "(bit-wise misalignment)");
588 }
589 int emulatedPerContainerElem = containerBits / emulatedBits;
590
591 // Adjust the number of elements to store when emulating narrow types.
592 // Here only the 1-D vector store is considered, and the N-D memref types
593 // should be linearized.
594 // For example, to emulate i4 to i8, the following op:
595 //
596 // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
597 //
598 // can be replaced with
599 //
600 // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
601 // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
602 // vector<4xi8>
603
604 auto origElements = valueToStore.getType().getNumElements();
605 // Note, per-element-alignment was already verified above.
606 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
607
608    // In assume-aligned mode, isDivisibleInSize alone is sufficient: the
609 // caller guarantees that store offsets are aligned to container element
610 // boundaries.
611 if (assumeAligned) {
612 if (!isDivisibleInSize)
613 return rewriter.notifyMatchFailure(
614 op, "the source vector does not fill whole container elements "
615 "(not divisible in size)");
616
617 auto stridedMetadata =
618 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
619 OpFoldResult linearizedIndices;
620 std::tie(std::ignore, linearizedIndices) =
621        memref::getLinearizedMemRefOffsetAndSize(
622            rewriter, loc, emulatedBits, containerBits,
623 stridedMetadata.getConstifiedMixedOffset(),
624 stridedMetadata.getConstifiedMixedSizes(),
625 stridedMetadata.getConstifiedMixedStrides(),
626 getAsOpFoldResult(adaptor.getIndices()));
627 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
628 int numElements = origElements / emulatedPerContainerElem;
629 auto bitCast = vector::BitCastOp::create(
630 rewriter, loc, VectorType::get(numElements, containerElemTy),
631 op.getValueToStore());
632 rewriter.replaceOpWithNewOp<vector::StoreOp>(
633 op, bitCast.getResult(), memrefBase,
634 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
635 return success();
636 }
637
638    // Does the trailing dim of the source and destination match? If yes, then
639    // the corresponding index must be 0.
640 // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
641 // However, that makes some tests fail, so we need to audit first.
642 auto trailingDim = op.getBase().getType().getShape().back();
643 bool trailingDimsMatch =
644 ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
645
646 auto stridedMetadata =
647 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
648
649 // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
650 // non-zero. As such, this is not needed.
651 OpFoldResult linearizedIndices;
652 memref::LinearizedMemRefInfo linearizedInfo;
653 std::tie(linearizedInfo, linearizedIndices) =
654        memref::getLinearizedMemRefOffsetAndSize(
655            rewriter, loc, emulatedBits, containerBits,
656 stridedMetadata.getConstifiedMixedOffset(),
657 stridedMetadata.getConstifiedMixedSizes(),
658 stridedMetadata.getConstifiedMixedStrides(),
659 getAsOpFoldResult(adaptor.getIndices()));
660
661 std::optional<int64_t> foldedNumFrontPadElems =
662 (isDivisibleInSize && trailingDimsMatch)
663 ? 0
664 : getConstantIntValue(linearizedInfo.intraDataOffset);
665
666 if (!foldedNumFrontPadElems) {
667 return rewriter.notifyMatchFailure(
668 op, "subbyte store emulation: dynamic front padding size is "
669 "not yet implemented");
670 }
671
672 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
673
674 // RMWs are not needed when:
675 // * no _partial_ stores are required.
676 // A partial store is defined as a store in which only a part of the
677 // container element is overwritten, e.g.
678 //
679 // Dest before (8 bits)
680 // +----------+
681 // | 11000000 |
682 // +----------+
683 //
684 // Dest after storing 0xF at offset 4 (in bits)
685 // +----------+
686 // | 11001111 |
687 // +----------+
688 //
689    // At a higher level, this translates to:
690 // 1. The source vector size (in bits) is a multiple of byte size.
691 // 2. The address of the store is aligned to the container type width
692 // boundary.
693 //
694 // EXAMPLE 1:
695 // Requires partial store:
696 // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
697 //
698 // EXAMPLE 2:
699 // Does not require a partial store:
700 // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
701 //
702 // TODO: Take linearizedInfo.linearizedOffset into account. This is
703 // currently not needed/used/exercised as all our tests set offset to 0.
704 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
705
706 if (!emulationRequiresPartialStores) {
707 // Basic case: storing full bytes.
708 auto numElements = origElements / emulatedPerContainerElem;
709 auto bitCast = vector::BitCastOp::create(
710 rewriter, loc, VectorType::get(numElements, containerElemTy),
711 op.getValueToStore());
712 rewriter.replaceOpWithNewOp<vector::StoreOp>(
713 op, bitCast.getResult(), memrefBase,
714 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
715 return success();
716 }
717
718 // Next, handle the case when sub-byte read-modify-write
719 // sequences are needed to emulate a vector store.
720 // Here is an example:
721 //
722 // Vector to store: vector<7xi2>
723 // Value to store: 11 11 11 11 11 11 11 (all ones)
724 //
725 // Destination: memref<12xi2>
726 // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
727 //
728 // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
729 //
730 // Destination memref before:
731 //
732 // Byte 0 Byte 1 Byte 2
733 // +----------+----------+----------+
734 // | 00000000 | 00000000 | 00000000 |
735 // +----------+----------+----------+
736 //
737 // Destination memref after:
738 //
739 // Byte 0 Byte 1 Byte 2
740 // +----------+----------+----------+
741 // | 00001111 | 11111111 | 11000000 |
742 // +----------+----------+----------+
743 //
744 // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
745    // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
746 // requiring RMW access (atomicity is required).
747
748 // The index into the target memref we are storing to.
749 Value currentDestIndex =
750 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
751 // The index into the source vector we are currently processing.
752 auto currentSourceIndex = 0;
753
754 // Build a mask used for rmw.
755 auto subWidthStoreMaskType =
756 VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
757
758 auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
759
760 // 1. Partial width store for the leading byte.
761    // When the store address is not aligned to the emulated width boundary, deal
762    // with the unaligned part first so that the remaining elements are aligned to
763    // the width boundary.
764 auto frontSubWidthStoreElem =
765 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
766 emulatedPerContainerElem;
767 if (frontSubWidthStoreElem > 0) {
768 SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
769 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
770 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
771 origElements, true);
772 frontSubWidthStoreElem = origElements;
773 } else {
774        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
775                    frontSubWidthStoreElem, true);
776 }
777 auto frontMask = arith::ConstantOp::create(
778 rewriter, loc,
779 DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
780
781 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
782 auto value =
783 extractSliceIntoByte(rewriter, loc, valueToStore, 0,
784 frontSubWidthStoreElem, *foldedNumFrontPadElems);
785
786 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
787 cast<VectorValue>(value), frontMask.getResult());
788 }
789
790 if (currentSourceIndex >= origElements) {
791 rewriter.eraseOp(op);
792 return success();
793 }
794
795 // Increment the destination index by 1 to align to the emulated width
796 // boundary.
797 auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
798 currentDestIndex = arith::AddIOp::create(
799 rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
800
801 // 2. Full width store for the inner output bytes.
802 // After the previous step, the store address is aligned to the emulated
803 // width boundary.
804 int64_t fullWidthStoreSize =
805 (origElements - currentSourceIndex) / emulatedPerContainerElem;
806 int64_t numNonFullWidthElements =
807 fullWidthStoreSize * emulatedPerContainerElem;
808 if (fullWidthStoreSize > 0) {
809 auto fullWidthStorePart = staticallyExtractSubvector(
810 rewriter, loc, valueToStore, currentSourceIndex,
811 numNonFullWidthElements);
812
813 auto originType = cast<VectorType>(fullWidthStorePart.getType());
814 auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
815 auto storeType = VectorType::get(
816 {originType.getNumElements() / emulatedPerContainerElem},
817 memrefElemType);
818 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
819 fullWidthStorePart);
820 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
821 currentDestIndex);
822
823 currentSourceIndex += numNonFullWidthElements;
824 currentDestIndex = arith::AddIOp::create(
825 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
826 arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
827 }
828
829 // 3. Partial width store for the trailing output byte.
830 // It is needed when the residual length is smaller than the emulated width,
831 // which is not covered in step 2 above.
832 auto remainingElements = origElements - currentSourceIndex;
833 if (remainingElements != 0) {
834 auto subWidthStorePart =
835 extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
836 currentSourceIndex, remainingElements, 0);
837
838 // Generate back mask.
839 auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
840 std::fill_n(maskValues.begin(), remainingElements, 1);
841 auto backMask = arith::ConstantOp::create(
842 rewriter, loc,
843 DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
844
845 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
846 cast<VectorValue>(subWidthStorePart), backMask.getResult());
847 }
848
849 rewriter.eraseOp(op);
850 return success();
851 }
852
853private:
854 const bool disableAtomicRMW;
855 const bool assumeAligned;
856};
857
858//===----------------------------------------------------------------------===//
859// ConvertVectorMaskedStore
860//===----------------------------------------------------------------------===//
861
862/// Converts `vector.maskedstore` operations on narrow element types to work
863/// with wider, byte-aligned container types by adjusting the mask and using
864/// bitcasting.
865///
866/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
867/// (each `i8` container element holds two `i4` values) and storing with an
868/// adjusted mask.
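///
/// A sketch of the generated IR for that example (illustrative; the mask is
/// compressed via `getCompressedMaskOp` and the SSA names are made up):
///   %old = vector.maskedload %base[%idx], %new_mask, %zeros
///            : memref<3xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %old_i4 = vector.bitcast %old : vector<3xi8> to vector<6xi4>
///   %sel = arith.select %mask, %val_to_store, %old_i4
///            : vector<6xi1>, vector<6xi4>
///   %new = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
///   vector.maskedstore %base[%idx], %new_mask, %new
///            : memref<3xi8>, vector<3xi1>, vector<3xi8>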
869struct ConvertVectorMaskedStore final
870 : OpConversionPattern<vector::MaskedStoreOp> {
871 using Base::Base;
872
873 LogicalResult
874 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
875 ConversionPatternRewriter &rewriter) const override {
876
877 // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
878 if (op.getValueToStore().getType().getRank() != 1)
879 return rewriter.notifyMatchFailure(
880 op, "Memref in vector.maskedstore op must be flattened beforehand.");
881
882 auto loc = op.getLoc();
883 auto containerElemTy =
884 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
885 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
886 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
887 int containerBits = containerElemTy.getIntOrFloatBitWidth();
888
889 // Check per-element alignment.
890 if (containerBits % emulatedBits != 0) {
891 return rewriter.notifyMatchFailure(
892 op, "impossible to pack emulated elements into container elements "
893 "(bit-wise misalignment)");
894 }
895
896 int emulatedPerContainerElem = containerBits / emulatedBits;
897 int origElements = op.getValueToStore().getType().getNumElements();
898 if (origElements % emulatedPerContainerElem != 0)
899 return failure();
900
901 auto stridedMetadata =
902 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
903 OpFoldResult linearizedIndicesOfr;
904 memref::LinearizedMemRefInfo linearizedInfo;
905 std::tie(linearizedInfo, linearizedIndicesOfr) =
906        memref::getLinearizedMemRefOffsetAndSize(
907            rewriter, loc, emulatedBits, containerBits,
908 stridedMetadata.getConstifiedMixedOffset(),
909 stridedMetadata.getConstifiedMixedSizes(),
910 stridedMetadata.getConstifiedMixedStrides(),
911 getAsOpFoldResult(adaptor.getIndices()));
912 Value linearizedIndices =
913 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
914
915 // Load the whole data and use arith.select to handle the corner cases.
916 //
917 // As an example, for this masked store of i4 values:
918 //
919 // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
920 //
921 // and given these input values:
922 //
923 // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
924 // %0[%c0, %c0] =
925 // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
926 // %val_to_store =
927 // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
928 //
929 // we'll have the following i4 output:
930 //
931 // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
932 //
933 // Emulating the above using i8 will give:
934 //
935 // %compressed_mask = [1, 1, 1, 0] (4 * i1)
936 // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
937 // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
938 // %select_using_shifted_mask =
939 // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
940 // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
941 //
942 // Using the compressed mask to store %packed_data results in expected
943 // output.
944 //
945 // FIXME: Make an example based on the comment above work (see #115460 for
946 // reproducer).
947 FailureOr<Operation *> newMask = getCompressedMaskOp(
948 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
949 if (failed(newMask))
950 return failure();
951
952 auto numElements = (origElements + emulatedPerContainerElem - 1) /
953 emulatedPerContainerElem;
954 auto newType = VectorType::get(numElements, containerElemTy);
955 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
956 rewriter.getZeroAttr(newType));
957
958 auto newLoad = vector::MaskedLoadOp::create(
959 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
960 newMask.value()->getResult(0), passThru);
961
962 auto newBitCastType =
963 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
964 Value valueToStore =
965 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
966 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
967 op.getValueToStore(), valueToStore);
968 valueToStore =
969 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
970
971 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
972 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
973 valueToStore);
974 return success();
975 }
976};
977
978//===----------------------------------------------------------------------===//
979// ConvertVectorLoad
980//===----------------------------------------------------------------------===//
981
982/// Converts `vector.load` on narrow element types to work with
983/// wider, byte-aligned container types by adjusting load sizes and using
984/// bitcasting.
985///
986/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
987/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
988/// container holds two `i4` values) and bitcasting back.
989///
990/// There are cases where the number of elements to load is not byte-aligned. In
991/// those cases, loads are converted to byte-aligned, byte-sized loads and the
992/// target vector is extracted from the loaded vector.
993struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
994 using Base::Base;
995
996 LogicalResult
997 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
998 ConversionPatternRewriter &rewriter) const override {
999 // Prerequisite: memref in the vector.load op is flattened into 1-D.
1000 if (op.getVectorType().getRank() != 1)
1001 return rewriter.notifyMatchFailure(
1002 op, "Memref in emulated vector ops must be flattened beforehand.");
1003
1004 auto loc = op.getLoc();
1005 auto containerElemTy =
1006 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1007 Type emulatedElemTy = op.getType().getElementType();
1008 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1009 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1010
1011 // Check per-element alignment.
1012 if (containerBits % emulatedBits != 0) {
1013 return rewriter.notifyMatchFailure(
1014 op, "impossible to pack emulated elements into container elements "
1015 "(bit-wise misalignment)");
1016 }
1017 int emulatedPerContainerElem = containerBits / emulatedBits;
1018
1019 // Adjust the number of elements to load when emulating narrow types,
1020 // and then cast back to the original type with vector.bitcast op.
1021 // For example, to emulate i4 to i8, the following op:
1022 //
1023 // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
1024 //
1025 // can be replaced with
1026 //
1027 // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
1028 // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
1029 //
1030 // There are cases where the number of elements to load is not byte-aligned,
1031 // for example:
1032 //
1033 // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
1034 //
1035 // we will have to load extra bytes and extract the exact slice in between.
1036 //
1037 // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
1038 // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
1039    //   %3 = vector.extract_strided_slice %2
1040    //          {offsets = [2], sizes = [3], strides = [1]}
1041    //        : vector<8xi2> to vector<3xi2>
1042 //
1043 // TODO: Currently the extract_strided_slice's attributes must be known at
1044 // compile time as they must be constants.
1045
1046 auto origElements = op.getVectorType().getNumElements();
1047 // Note, per-element-alignment was already verified above.
1048 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1049
1050 auto stridedMetadata =
1051 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1052
1053 OpFoldResult linearizedIndices;
1054 memref::LinearizedMemRefInfo linearizedInfo;
1055 std::tie(linearizedInfo, linearizedIndices) =
1056        memref::getLinearizedMemRefOffsetAndSize(
1057            rewriter, loc, emulatedBits, containerBits,
1058 stridedMetadata.getConstifiedMixedOffset(),
1059 stridedMetadata.getConstifiedMixedSizes(),
1060 stridedMetadata.getConstifiedMixedStrides(),
1061 getAsOpFoldResult(adaptor.getIndices()));
1062
1063 std::optional<int64_t> foldedIntraVectorOffset =
1064 isDivisibleInSize ? 0
1065 : getConstantIntValue(linearizedInfo.intraDataOffset);
1066
1067 // Always load enough elements which can cover the original elements.
1068    int64_t maxIntraDataOffset =
1069        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1070    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1071 emulatedPerContainerElem);
1072 Value result =
1073 emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1074 numElements, emulatedElemTy, containerElemTy);
1075
1076 if (!foldedIntraVectorOffset) {
1077 auto resultVector = arith::ConstantOp::create(
1078 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1079      result = dynamicallyExtractSubVector(
1080          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1081 linearizedInfo.intraDataOffset, origElements);
1082 } else if (!isDivisibleInSize) {
1083      result = staticallyExtractSubvector(
1084          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1085 }
1086 rewriter.replaceOp(op, result);
1087 return success();
1088 }
1089};
1090
1091//===----------------------------------------------------------------------===//
1092// ConvertVectorMaskedLoad
1093//===----------------------------------------------------------------------===//
1094
1095/// Converts `vector.maskedload` operations on narrow element types to work with
1096/// wider, byte-aligned container types by adjusting the mask and using
1097/// bitcasting.
1098///
1099/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1100/// bitcasting, since each `i8` container element holds two `i4` values.
1101struct ConvertVectorMaskedLoad final
1102 : OpConversionPattern<vector::MaskedLoadOp> {
1103 using Base::Base;
1104
1105 LogicalResult
1106 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1107 ConversionPatternRewriter &rewriter) const override {
1108 if (op.getVectorType().getRank() != 1)
1109 return rewriter.notifyMatchFailure(
1110 op, "Memref in emulated vector ops must be flattened beforehand.");
1111
1112 auto loc = op.getLoc();
1113
1114 auto containerElemTy =
1115 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1116 Type emulatedElemTy = op.getType().getElementType();
1117 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1118 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1119
1120 // Check per-element alignment.
1121 if (containerBits % emulatedBits != 0) {
1122 return rewriter.notifyMatchFailure(
1123 op, "impossible to pack emulated elements into container elements "
1124 "(bit-wise misalignment)");
1125 }
1126 int emulatedPerContainerElem = containerBits / emulatedBits;
1127
1128 // Adjust the number of elements to load when emulating narrow types,
1129 // and then cast back to the original type with vector.bitcast op.
1130 // For example, to emulate i4 to i8, the following op:
1131 //
1132 // %mask = vector.constant_mask [3] : vector<6xi1>
1133 // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1134 // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1135 //
1136 // can be replaced with
1137 //
1138 // %new_mask = vector.constant_mask [2] : vector<3xi1>
1139 // %new_pass_thru = vector.bitcast %pass_thru :
1140 // vector<6xi4> to vector<3xi8>
1141 // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1142 // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1143 // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1144 //
1145 // Since we are effectively loading 16 bits (2xi8) from the memref with the
1146 // new mask, while originally we only wanted to effectively load 12 bits
1147 // (3xi4) from the memref, we need to set the second half of the last i8
1148 // that was effectively loaded (i.e. the second i8) to %pass_thru.
1149 //
1150 // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1151 //
1152 // Given these input values:
1153 // %mask = [1, 1, 1, 0, 0, 0]
1154 // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1155 // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1156 //
1157 // we'll have:
1158 //
1159 // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1160 //
1161 // %new_mask = [1, 1, 0]
1162 // %new_pass_thru = [0x78, 0x9A, 0xBC]
1163 // %1 = [0x12, 0x34, 0xBC]
1164 // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1165 // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1166 //
1167 // TODO: Currently, only the even number of elements loading is supported.
1168 // To deal with the odd number of elements, one has to extract the
1169 // subvector at the proper offset after bit-casting.
1170 auto origType = op.getVectorType();
1171 auto origElements = origType.getNumElements();
1172 // Note, per-element-alignment was already verified above.
1173 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1174
1175 auto stridedMetadata =
1176 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1177 OpFoldResult linearizedIndices;
1178 memref::LinearizedMemRefInfo linearizedInfo;
1179 std::tie(linearizedInfo, linearizedIndices) =
1180        memref::getLinearizedMemRefOffsetAndSize(
1181            rewriter, loc, emulatedBits, containerBits,
1182 stridedMetadata.getConstifiedMixedOffset(),
1183 stridedMetadata.getConstifiedMixedSizes(),
1184 stridedMetadata.getConstifiedMixedStrides(),
1185 getAsOpFoldResult(adaptor.getIndices()));
1186
1187 std::optional<int64_t> foldedIntraVectorOffset =
1188 isDivisibleInSize ? 0
1189 : getConstantIntValue(linearizedInfo.intraDataOffset);
1190
1191 int64_t maxIntraDataOffset =
1192 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1193 FailureOr<Operation *> newMask =
1194 getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1195 emulatedPerContainerElem, maxIntraDataOffset);
1196 if (failed(newMask))
1197 return failure();
1198
1199 Value passthru = op.getPassThru();
1200
1201 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1202 emulatedPerContainerElem);
1203 auto loadType = VectorType::get(numElements, containerElemTy);
1204 auto newBitcastType =
1205 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1206
1207 auto emptyVector = arith::ConstantOp::create(
1208 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1209 if (!foldedIntraVectorOffset) {
1210 passthru = dynamicallyInsertSubVector(
1211 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1212 origElements);
1213 } else if (!isDivisibleInSize) {
1214 passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1215 *foldedIntraVectorOffset);
1216 }
1217 auto newPassThru =
1218 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1219
1220 // Generating the new masked load.
1221 auto newLoad = vector::MaskedLoadOp::create(
1222 rewriter, loc, loadType, adaptor.getBase(),
1223 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1224 newMask.value()->getResult(0), newPassThru);
1225
1226 // Setting the part that originally was not effectively loaded from memory
1227 // to pass through.
1228 auto bitCast =
1229 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1230
1231 Value mask = op.getMask();
1232 auto newSelectMaskType = VectorType::get(
1233 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1234 // TODO: try to fold if op's mask is constant
1235 auto emptyMask =
1236 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1237 rewriter.getZeroAttr(newSelectMaskType));
1238 if (!foldedIntraVectorOffset) {
1239 mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1240 linearizedInfo.intraDataOffset,
1241 origElements);
1242 } else if (!isDivisibleInSize) {
1243 mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1244 *foldedIntraVectorOffset);
1245 }
1246
1247 Value result =
1248 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1249 if (!foldedIntraVectorOffset) {
1250      result = dynamicallyExtractSubVector(
1251          rewriter, loc, result, op.getPassThru(),
1252 linearizedInfo.intraDataOffset, origElements);
1253 } else if (!isDivisibleInSize) {
1254      result = staticallyExtractSubvector(
1255          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1256 }
1257 rewriter.replaceOp(op, result);
1258
1259 return success();
1260 }
1261};
1262
1263/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
1264///
1265/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1266/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1267/// (a multi-byte scalar, e.g. i16), where N is some integer.
1268///
1269/// Put differently, this method checks whether this would be valid:
1270///
1271/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1272///
1273/// EXAMPLES:
1274/// * vector<4xi4> -> i16 - yes (N = 1)
1275/// * vector<4xi4> -> i8 - yes (N = 2)
1276/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1277///    * vector<3xi2> -> i16 - no (N would have to be 0.375)
1278static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1279 Type multiByteScalarTy) {
1280 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1281
1282 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1283 int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1284
1285 assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1286 assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1287  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1288
1289 int elemsPerMultiByte = multiByteBits / subByteBits;
1290
1291 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1292}
1293
1294//===----------------------------------------------------------------------===//
1295// ConvertVectorTransferRead
1296//===----------------------------------------------------------------------===//
1297
1298// TODO: Document-me
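// Emulates a 1-D `vector.transfer_read` on a narrow element type by reading
// the container type and bitcasting back. Roughly, for example (an
// illustrative sketch only):
//
//   %0 = vector.transfer_read %src[%idx], %pad : memref<?xi2>, vector<4xi2>
//
// becomes
//
//   %pad_i8 = arith.extui %pad : i2 to i8
//   %1 = vector.transfer_read %src_i8[%lin_idx], %pad_i8
//          : memref<?xi8>, vector<1xi8>
//   %2 = vector.bitcast %1 : vector<1xi8> to vector<4xi2>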
1299struct ConvertVectorTransferRead final
1300 : OpConversionPattern<vector::TransferReadOp> {
1301 using Base::Base;
1302
1303 LogicalResult
1304 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1305 ConversionPatternRewriter &rewriter) const override {
1306
1307 // Prerequisites: memref in the vector.transfer_read op is flattened into
1308 // 1-D.
1309 if (op.getVectorType().getRank() != 1)
1310 return rewriter.notifyMatchFailure(
1311 op, "Memref in emulated vector ops must be flattened beforehand.");
1312
1313 auto loc = op.getLoc();
1314 auto containerElemTy =
1315 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1316 Type emulatedElemTy = op.getType().getElementType();
1317 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1318 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1319
1320 // Check per-element alignment.
1321 if (containerBits % emulatedBits != 0) {
1322 return rewriter.notifyMatchFailure(
1323 op, "impossible to pack emulated elements into container elements "
1324 "(bit-wise misalignment)");
1325 }
1326 int emulatedPerContainerElem = containerBits / emulatedBits;
1327
1328 auto origElements = op.getVectorType().getNumElements();
1329
1330 // Note, per-element-alignment was already verified above.
1331 bool isDivisibleInSize =
1332 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1333
1334 // Pad the padding value with 0s on the left. These bits are discarded and
1335 // thus their values don't matter.
1336 Value padding = adaptor.getPadding();
1337 if (!padding.getType().isInteger()) {
1338 padding = arith::BitcastOp::create(
1339 rewriter, loc,
1340 IntegerType::get(rewriter.getContext(),
1341 padding.getType().getIntOrFloatBitWidth()),
1342 padding);
1343 }
1344 auto newPadding =
1345 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1346
1347 auto stridedMetadata =
1348 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1349
1350 OpFoldResult linearizedIndices;
1351 memref::LinearizedMemRefInfo linearizedInfo;
1352 std::tie(linearizedInfo, linearizedIndices) =
1353        memref::getLinearizedMemRefOffsetAndSize(
1354            rewriter, loc, emulatedBits, containerBits,
1355 stridedMetadata.getConstifiedMixedOffset(),
1356 stridedMetadata.getConstifiedMixedSizes(),
1357 stridedMetadata.getConstifiedMixedStrides(),
1358 getAsOpFoldResult(adaptor.getIndices()));
1359
1360 std::optional<int64_t> foldedIntraVectorOffset =
1361 isDivisibleInSize ? 0
1362 : getConstantIntValue(linearizedInfo.intraDataOffset);
1363
1364 int64_t maxIntraDataOffset =
1365 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1366 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1367 emulatedPerContainerElem);
1368
1369 auto newRead = vector::TransferReadOp::create(
1370 rewriter, loc, VectorType::get(numElements, containerElemTy),
1371 adaptor.getBase(),
1372 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1373 newPadding);
1374
1375 auto bitCast = vector::BitCastOp::create(
1376 rewriter, loc,
1377 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1378 newRead);
1379
1380 Value result = bitCast->getResult(0);
1381 if (!foldedIntraVectorOffset) {
1382 auto zeros = arith::ConstantOp::create(
1383 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1384 result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1385 linearizedInfo.intraDataOffset,
1386 origElements);
1387 } else if (!isDivisibleInSize) {
1388      result = staticallyExtractSubvector(
1389          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1390 }
1391 rewriter.replaceOp(op, result);
1392
1393 return success();
1394 }
1395};
1396} // end anonymous namespace
1397
1398//===----------------------------------------------------------------------===//
1399// RewriteBitCastOfTruncI
1400//===----------------------------------------------------------------------===//
1401
1402namespace {
1403
1404/// Helper struct to keep track of the provenance of a contiguous set of bits
1405/// in a source vector.
1406struct SourceElementRange {
1407 /// The index of the source vector element that contributes bits to *this.
1408 int64_t sourceElementIdx;
1409 /// The range of bits in the source vector element that contribute to *this.
1410 int64_t sourceBitBegin;
1411 int64_t sourceBitEnd;
1412};
1413
1414struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1415 /// Given the index of a SourceElementRange in the SourceElementRangeList,
1416 /// compute the amount of bits that need to be shifted to the left to get the
1417 /// bits in their final location. This shift amount is simply the sum of the
1418 /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1419/// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
1420 int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1421 int64_t res = 0;
1422 for (int64_t i = 0; i < shuffleIdx; ++i)
1423 res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1424 return res;
1425 }
1426};
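// For illustration, assume a SourceElementRangeList with three entries (the
// source indices and bit ranges below are chosen arbitrarily):
//   { A: b@[3..5) } { B: b@[0..5) } { C: b@[0..1) }
// then:
//   computeLeftShiftAmount(0) == 0
//   computeLeftShiftAmount(1) == (5 - 3)           == 2
//   computeLeftShiftAmount(2) == (5 - 3) + (5 - 0) == 7
// i.e. entry 0 lands in the LSBs, entry 1 starts right above entry 0's two
// bits, and entry 2 starts above the 2 + 5 bits already placed.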
1427
1428/// Helper struct to enumerate the source elements and bit ranges that are
1429/// involved in a bitcast operation.
1430/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1431/// any 1-D vector shape and any source/target bitwidths.
1432/// This creates and holds a mapping of the form:
1433/// [dstVectorElementJ] ==
1434/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1435/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1436/// [0] = {0, [0-8)}
1437/// [1] = {0, [8-16)}
1438/// [2] = {0, [16-24)}
1439/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1440/// [0] = {0, [0, 10)}
1441/// [1] = {0, [10, 15)}, {1, [0, 5)}
/// [2] = {1, [5, 15)}
1442struct BitCastBitsEnumerator {
1443 BitCastBitsEnumerator(VectorType sourceVectorType,
1444 VectorType targetVectorType);
1445
1446 int64_t getMaxNumberOfEntries() {
1447 int64_t numVectors = 0;
1448 for (const auto &l : sourceElementRanges)
1449 numVectors = std::max(numVectors, (int64_t)l.size());
1450 return numVectors;
1451 }
1452
1453 VectorType sourceVectorType;
1454 VectorType targetVectorType;
1455 SmallVector<SourceElementRangeList> sourceElementRanges;
1456};
1457
1458/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1459/// advantage of high-level information to avoid leaving LLVM to scramble with
1460/// peephole optimizations.
1461/// BitCastBitsEnumerator encodes for each element of the target vector the
1462/// provenance of the bits in the source vector. We can "transpose" this
1463/// information to build a sequence of shuffles and bitwise ops that will
1464/// produce the desired result.
1465//
1466/// Consider the following motivating example:
1467/// ```
1468/// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1469/// ```
1470//
1471/// BitCastBitsEnumerator contains the following information:
1472/// ```
1473/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1474/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1475/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1476/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1477/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1478/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1479/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1480/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1481/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1482/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1483/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1484/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1485/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1486/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1487/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1488/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1489/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1490/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1491/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1492/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1493/// ```
1494///
1495/// In the above, each row represents one target vector element and each
1496/// column represents one bit contribution from a source vector element.
1497/// The algorithm creates vector.shuffle operations (in this case there are 3
1498/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1499/// algorithm populates the bits as follows:
1500/// ```
1501/// src bits 0 ...
1502/// 1st shuffle |xxxxx |xx |...
1503/// 2nd shuffle | xxx| xxxxx |...
1504/// 3rd shuffle | | x|...
1505/// ```
1506//
1507/// The algorithm proceeds as follows:
1508/// 1. for each vector.shuffle, collect the source vectors that participate in
1509/// this shuffle. One source vector per target element of the resulting
1510/// vector.shuffle. If there is no source element contributing bits for the
1511/// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1512/// 2 columns).
1513/// 2. represent the bitrange in the source vector as a mask. If there is no
1514/// source element contributing bits for the current vector.shuffle, take 0.
1515/// 3. shift right by the proper amount to align the source bitrange at
1516/// position 0. This is exactly the low end of the bitrange. For instance,
1517/// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1518/// shift right by 3 to get the bits contributed by the source element #1
1519/// into position 0.
1520/// 4. shift left by the proper amount to align to the desired position in
1521/// the result element vector. For instance, the contribution of the second
1522/// source element for the first row needs to be shifted by `5` to form the
1523/// first i8 result element.
1524///
1525/// Eventually, we end up building the sequence
1526/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1527/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1528/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1529struct BitCastRewriter {
1530 /// Helper metadata struct to hold the static quantities for the rewrite.
1531 struct Metadata {
1532 SmallVector<int64_t> shuffles;
1533 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1534 };
1535
1536 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1537
1538 /// Verify that general preconditions for the rewrite are met.
1539 LogicalResult commonPrecondition(PatternRewriter &rewriter,
1540 VectorType preconditionType, Operation *op);
1541
1542 /// Precompute the metadata for the rewrite.
1543 SmallVector<BitCastRewriter::Metadata>
1544 precomputeMetadata(IntegerType shuffledElementType);
1545
1546 /// Rewrite one step of the sequence:
1547 /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1548 Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1549 Value initialValue, Value runningResult,
1550 const BitCastRewriter::Metadata &metadata);
1551
1552private:
1553 /// Underlying enumerator that encodes the provenance of the bits in each
1554 /// element of the result vector.
1555 BitCastBitsEnumerator enumerator;
1556};
1557
1558} // namespace
1559
1560[[maybe_unused]] static raw_ostream &
1561operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1562 for (const auto &l : vec) {
1563 for (auto it : llvm::enumerate(l)) {
1564 os << "{ " << it.value().sourceElementIdx << ": b@["
1565 << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1566 << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1567 }
1568 os << "\n";
1569 }
1570 return os;
1571}
1572
1573BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1574 VectorType targetVectorType)
1575 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1576
1577 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1578 "requires 1-D non-scalable vector type");
1579 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1580 "requires 1-D non-scalable vector type");
1581 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1582 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1583 LDBG() << "sourceVectorType: " << sourceVectorType;
1584
1585 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1586 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1587 LDBG() << "targetVectorType: " << targetVectorType;
1588
1589 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1590 (void)mostMinorSourceDim;
1591 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1592 "source and target bitwidths must match");
1593
1594 // Prepopulate one source element range per target element.
1595 sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1596 for (int64_t resultBit = 0; resultBit < bitwidth;) {
1597 int64_t resultElement = resultBit / targetBitWidth;
1598 int64_t resultBitInElement = resultBit % targetBitWidth;
1599 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1600 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1601 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1602 targetBitWidth - resultBitInElement);
1603 sourceElementRanges[resultElement].push_back(
1604 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1605 resultBit += step;
1606 }
1607}
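// As a worked example of the loop above, consider
// `vector.bitcast ... : vector<2xi15> to vector<3xi10>` (30 bits in total):
//   resultBit =  0: result elt 0 <- source elt 0, bits [ 0, 10), step = 10
//   resultBit = 10: result elt 1 <- source elt 0, bits [10, 15), step =  5
//   resultBit = 15: result elt 1 <- source elt 1, bits [ 0,  5), step =  5
//   resultBit = 20: result elt 2 <- source elt 1, bits [ 5, 15), step = 10
// which reproduces the decomposition listed in the BitCastBitsEnumerator
// comment above.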
1608
1609BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1610 VectorType targetVectorType)
1611 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1612 LDBG() << "\n" << enumerator.sourceElementRanges;
1613}
1614
1615/// Verify that the precondition type meets the common preconditions for any
1616/// conversion.
1617static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1618 VectorType preconditionType,
1619 Operation *op) {
1620 if (!preconditionType || preconditionType.isScalable())
1621 return rewriter.notifyMatchFailure(op, "scalable vector");
1622
1623 // TODO: consider relaxing this restriction in the future if we find ways
1624 // to really work with subbyte elements across the MLIR/LLVM boundary.
1625 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1626 if (bitwidth % 8 != 0)
1627 return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1628
1629 return success();
1630}
1631
1632LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1633 VectorType preconditionType,
1634 Operation *op) {
1635 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1636 return rewriter.notifyMatchFailure(op, "types are not vector");
1637
1638 if (!preconditionType || preconditionType.getRank() != 1)
1639 return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1640
1641 return commonConversionPrecondition(rewriter, preconditionType, op);
1642}
1643
1644/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1645///
1646/// Alignment means that `subByteVecTy` can be packed into a vector of
1647/// `containerTy` elements. More specifically:
1648/// 1. The bit-width of `containerTy` is a multiple of the
1649/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1650/// this multiple is 4.
1651/// 2. The multiple from 1. above evenly divides the number of (trailing)
1652///    elements in `subByteVecTy`.
1653///
1654/// EXAMPLE 1:
1655/// `subByteVecTy = vector<2xi4>`, and
1656/// `containerTy = i16`
1657///
1658/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1659///
1660/// EXAMPLE 2:
1661/// `subByteVecTy = vector<3xi4>`, and
1662/// `containerTy = i16`
1663///
1664/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1665///
1666/// EXAMPLE 3:
1667/// `subByteVecTy = vector<3xi3>`, and
1668/// `containerTy = i16`
1669///
1670/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1671///
1672/// NOTE: This method assumes that common conversion preconditions are met. In
1673/// particular, `containerTy` is assumed to be a
1674/// multi-byte scalar type (e.g., i8, i16, i32).
1675static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1676 VectorType subByteVecTy,
1677 Type containerTy,
1678 Operation *op) {
1679 assert(containerTy.isIntOrFloat() &&
1680 "container element type is not a scalar");
1681
1682 // TODO: This is validating the inputs rather than checking the conditions
1683 // documented above. Replace with an assert.
1684 if (!subByteVecTy)
1685 return rewriter.notifyMatchFailure(op, "not a vector!");
1686
1687 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1688 unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1689
1690 // Enforced by the common pre-conditions.
1691 assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1692
1693 // TODO: Add support for other widths (when/if needed).
1694 if (subByteBits != 2 && subByteBits != 4)
1695 return rewriter.notifyMatchFailure(
1696 op, "only 2-bit and 4-bit sub-byte types are supported at the moment");
1697
1698 // Condition 1 ("per-element" alignment)
1699 if (containerBits % subByteBits != 0)
1700 return rewriter.notifyMatchFailure(op, "unaligned element types");
1701
1702 // Condition 2 ("full" alignment)
1703 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1704 return rewriter.notifyMatchFailure(
1705 op, "not possible to fit this sub-byte vector type into a vector of "
1706 "the given multi-byte type");
1707
1708 return success();
1709}
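// As a quick illustration of the two conditions above with an i8 container
// (the container type used by the aligned patterns below):
//   subByteVecTy = vector<8xi2>, containerTy = i8
//   condition 1: 8 % 2 == 0, i.e. 4 sub-byte elements per container element;
//   condition 2: the 8 trailing i2 elements pack exactly into vector<2xi8>.
// By contrast, vector<3xi2> would fail condition 2: 3 i2 elements do not fill
// a whole number of i8 elements.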
1710
1711SmallVector<BitCastRewriter::Metadata>
1712BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1713 SmallVector<BitCastRewriter::Metadata> result;
1714 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1715 shuffleIdx < e; ++shuffleIdx) {
1716 SmallVector<int64_t> shuffles;
1717 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1718
1719 // Create the attribute quantities for the shuffle / mask / shift ops.
1720 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1721 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1722 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1723 : 0;
1724 shuffles.push_back(sourceElement);
1725
1726 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1727 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1728 : 0;
1729 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1730 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1731 : 0;
1732 IntegerAttr mask = IntegerAttr::get(
1733 shuffledElementType,
1734 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1735 bitLo, bitHi));
1736 masks.push_back(mask);
1737
1738 int64_t shiftRight = bitLo;
1739 shiftRightAmounts.push_back(
1740 IntegerAttr::get(shuffledElementType, shiftRight));
1741
1742 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1743 shiftLeftAmounts.push_back(
1744 IntegerAttr::get(shuffledElementType, shiftLeft));
1745 }
1746
1747 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1748 }
1749 return result;
1750}
1751
1752Value BitCastRewriter::genericRewriteStep(
1753 PatternRewriter &rewriter, Location loc, Value initialValue,
1754 Value runningResult, const BitCastRewriter::Metadata &metadata) {
1755 // Create vector.shuffle from the metadata.
1756 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1757 initialValue, metadata.shuffles);
1758
1759 // Intersect with the mask.
1760 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1761 auto constOp = arith::ConstantOp::create(
1762 rewriter, loc,
1763 DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1764 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1765
1766 // Align right on 0.
1767 auto shiftRightConstantOp = arith::ConstantOp::create(
1768 rewriter, loc,
1769 DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1770 Value shiftedRight =
1771 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1772
1773 // Shift bits left into their final position.
1774 auto shiftLeftConstantOp = arith::ConstantOp::create(
1775 rewriter, loc,
1776 DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1777 Value shiftedLeft =
1778 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1779
1780 runningResult =
1781 runningResult
1782 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1783 : shiftedLeft;
1784
1785 return runningResult;
1786}
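// One rewrite step, schematically. The element type iW, the vector sizes and
// the shuffle indices / splat constants below are placeholders standing for
// the precomputed Metadata, not concrete values:
//
//   %shuf = vector.shuffle %init, %init [s_0, ..., s_{N-1}]
//             : vector<MxiW>, vector<MxiW>
//   %and  = arith.andi  %shuf, %masks             : vector<NxiW>
//   %shr  = arith.shrui %and,  %shiftRightAmounts : vector<NxiW>
//   %shl  = arith.shli  %shr,  %shiftLeftAmounts  : vector<NxiW>
//   %acc  = arith.ori   %prev, %shl               : vector<NxiW>
//
// The final arith.ori is only emitted once a running result exists, i.e. for
// every step after the first one.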
1787
1788/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1789/// Here, "aligned" means that it satisfies alignedConversionPrecondition.
1790///
1791/// Example:
1792/// vector<16x16xi2> -> vector<16x4xi8>
1793/// vector<16x16xi4> -> vector<16x8xi8>
1794static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1795 Value subByteVec) {
1796 auto srcVecType = cast<VectorType>(subByteVec.getType());
1797 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1798 assert(8 % srcBitwidth == 0 &&
1799 "Unsupported sub-byte type (not a divisor of i8)");
1800 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1801 SmallVector<int64_t> vecShape(srcVecType.getShape());
1802 // Adjust last dimension of the vector, so the total size remains the same.
1803 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1804 auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1805 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1806}
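// E.g. for the first example in the comment above, the single emitted op is:
//   %0 = vector.bitcast %src : vector<16x16xi2> to vector<16x4xi8>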
1807
1808/// Extracts a signed N-bit sequence from each element of a vector of bytes,
1809/// starting at the specified bit index.
1810/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1811///
1812/// Example for a single element:
1813/// Extract numBits=2 starting at bitIdx=2
1814/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1815/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1816/// target = [. . . . ^ ^ . .]
1817///
1818/// The target sequence is [11] (decimal=-1) as a signed 2-bit integer.
1819/// So the result should be [11 11 11 11] (decimal=-1) as a signed 8-bit integer.
1820///
1821/// src = [01 01 11 10]
1822/// shl = arith.shl(src, 4) -> [11 10 00 00]
1823/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1824static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1825 Location loc, Value src,
1826 int bitIdx, int numBits) {
1827 auto srcType = cast<VectorType>(src.getType());
1828 Value shl = src;
1829 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1830 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1831 "Invalid bitIdx range");
1832 if (bitsToShiftLeft != 0) {
1833 Value shiftLeftValues = arith::ConstantOp::create(
1834 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1835 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1836 }
1837
1838 int8_t bitsToShiftRight = 8 - numBits;
1839 Value shiftRightValues = arith::ConstantOp::create(
1840 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1841 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1842 return shr;
1843}
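// For the worked example in the comment above (numBits = 2, bitIdx = 2), and
// assuming a vector<4xi8> input, the emitted IR is:
//   %c4  = arith.constant dense<4> : vector<4xi8>
//   %shl = arith.shli %src, %c4 : vector<4xi8>
//   %c6  = arith.constant dense<6> : vector<4xi8>
//   %res = arith.shrsi %shl, %c6 : vector<4xi8>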
1844
1845/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1846/// starting at the specified bit index.
1847/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1848///
1849/// Example for a single element:
1850/// Extract numBits=2 starting at bitIdx=2
1851/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1852/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1853/// target = [. . . . ^ ^ . .]
1854///
1855/// The target sequence is [10] (decimal=2) as an unsigned 2-bit integer.
1856/// So the result should be [00 00 00 10] (decimal=2) as an unsigned 8-bit integer.
1857///
1858/// src = [01 01 10 10]
1859/// mask = [00 00 00 11]
1860/// shr = arith.shrui(src, 2) = [00 01 01 10]
1861/// result = arith.andi(shr, mask) = [00 00 00 10]
1862/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1863/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1864/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1865/// left when the index is 0.
1866static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1867 Location loc, Value src,
1868 int bitIdx, int numBits) {
1869 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1870 "Invalid bitIdx range");
1871 auto srcType = cast<VectorType>(src.getType());
1872 int8_t bitsToShiftRight = bitIdx;
1873 Value shr = src;
1874 if (bitsToShiftRight != 0) {
1875 Value shiftRightValues = arith::ConstantOp::create(
1876 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1877 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1878 }
1879 if (bitIdx + numBits == 8) {
1880 return shr;
1881 }
1882 uint8_t lowBitsMask = (1 << numBits) - 1;
1883 Value lowBitsMaskValues = arith::ConstantOp::create(
1884 rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1885 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1886}
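// For the worked example in the comment above (numBits = 2, bitIdx = 2), and
// assuming a vector<4xi8> input, the emitted IR is:
//   %c2   = arith.constant dense<2> : vector<4xi8>
//   %shr  = arith.shrui %src, %c2 : vector<4xi8>
//   %mask = arith.constant dense<3> : vector<4xi8>
//   %res  = arith.andi %shr, %mask : vector<4xi8>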
1887
1888using ExtractNBitsFn =
1889 std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1890
1891/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1892/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1893static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1894 Value srcValue, const ExtractNBitsFn &extFn) {
1895 [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1896 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1897 "Expected i4 type");
1898
1899 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1900 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1901
1902 // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1903 // byte are placed in one vector and the high i4 elements in another.
1904 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1905 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1906
1907 // 3. Interleave low and high i8 elements.
1908 return vector::InterleaveOp::create(rewriter, loc, low, high);
1909}
1910
1911/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1912/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1913static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1914 Value srcValue, const ExtractNBitsFn &extFn) {
1915 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1916 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1917 "Expected i2 type");
1918
1919 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1920 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1921
1922 // 2. Extract each i2 element
1923 // Position 0 (bits 0-1)
1924 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1925 // Position 1 (bits 2-3)
1926 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1927 // Position 2 (bits 4-5)
1928 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1929 // Position 3 (bits 6-7)
1930 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1931
1932 // 3. Interleave all 4 elements by first interleaving
1933 // even elements and then odd
1934 // vec0 = [0,0,0,0],...
1935 // vec1 = [1,1,1,1],...
1936 // vec2 = [2,2,2,2],...
1937 // vec3 = [3,3,3,3],...
1938 // 02 = [0,2,0,2,0,2,0,2],...
1939 // 13 = [1,3,1,3,1,3,1,3],...
1940 // 0213 = [0,1,2,3,...],...
1941 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1942 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1943 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1944 interleave13);
1945}
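// Putting the steps together for an assumed vector<16xi2> input (%v0..%v3 are
// the four per-slot vectors produced by the extFn calls above):
//   %bc  = vector.bitcast %in : vector<16xi2> to vector<4xi8>
//   %i02 = vector.interleave %v0, %v2 : vector<4xi8> -> vector<8xi8>
//   %i13 = vector.interleave %v1, %v3 : vector<4xi8> -> vector<8xi8>
//   %out = vector.interleave %i02, %i13 : vector<8xi8> -> vector<16xi8>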
1946
1947/// Rewrite the i8 -> i4 truncation into a deinterleave and a series of bitwise
1948/// ops to avoid leaving LLVM to scramble with peephole optimizations.
1949static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1950 Value srcValue) {
1951 VectorType srcVecType = cast<VectorType>(srcValue.getType());
1952 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1953 "Expected i8 type");
1954
1955 // 1. De-interleave low and high i8 elements.
1956 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1957
1958 // 2. Zero out the upper side of each low i8 element.
1959 constexpr int8_t i8LowBitMask = 0x0F;
1960 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1961 Value zeroOutMask = arith::ConstantOp::create(
1962 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1963 Value zeroOutLow = arith::AndIOp::create(
1964 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1965
1966 // 3. Move high i4 values to upper side of the byte.
1967 constexpr int8_t bitsToShift = 4;
1968 auto shiftValues = arith::ConstantOp::create(
1969 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1970 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1971 shiftValues);
1972
1973 // 4. Merge high and low i4 values.
1974 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1975
1976 // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1977 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1978 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1979}
1980
1981namespace {
1982/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1983/// advantage of high-level information to avoid leaving LLVM to scramble with
1984/// peephole optimizations.
1985struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1986 using Base::Base;
1987
1988 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1989 PatternRewriter &rewriter) const override {
1990 // The source must be a trunc op.
1991 auto truncOp =
1992 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1993 if (!truncOp)
1994 return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1995
1996 // Set up the BitCastRewriter and verify the precondition.
1997 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1998 VectorType targetVectorType = bitCastOp.getResultVectorType();
1999 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2000 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
2001 return failure();
2002
2003 // Perform the rewrite.
2004 Value truncValue = truncOp.getIn();
2005 auto shuffledElementType =
2006 cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
2007 Value runningResult;
2008 for (const BitCastRewriter ::Metadata &metadata :
2009 bcr.precomputeMetadata(shuffledElementType)) {
2010 runningResult = bcr.genericRewriteStep(
2011 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
2012 }
2013
2014 // Finalize the rewrite.
2015 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
2016 shuffledElementType.getIntOrFloatBitWidth();
2017 if (narrowing) {
2018 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
2019 rewriter.replaceOp(bitCastOp, runningResult);
2020 } else {
2021 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2022 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2023 }
2024 } else {
2025 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
2026 rewriter.replaceOp(bitCastOp, runningResult);
2027 } else {
2028 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2029 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
2030 }
2031 }
2032
2033 return success();
2034 }
2035};
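// A schematic example of what this pattern matches and produces (the shapes
// and element types below are made up for illustration):
//
//   %t = arith.trunci %wide : vector<8xi32> to vector<8xi4>
//   %b = vector.bitcast %t : vector<8xi4> to vector<4xi8>
//
// is rewritten into two of the shuffle/and/shift steps described above
// (chained with arith.ori), applied directly to %wide, followed by a final
// arith.trunci from vector<4xi32> to vector<4xi8> since the result element
// type (i8) is narrower than the shuffled element type (i32).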
2036} // namespace
2037
2038//===----------------------------------------------------------------------===//
2039// RewriteExtOfBitCast
2040//===----------------------------------------------------------------------===//
2041
2042namespace {
2043/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
2044/// take advantage of high-level information to avoid leaving LLVM to scramble
2045/// with peephole optimizations.
2046template <typename ExtOpType>
2047struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2048 using OpRewritePattern<ExtOpType>::OpRewritePattern;
2049
2050 RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2051 : OpRewritePattern<ExtOpType>(context, benefit) {}
2052
2053 LogicalResult matchAndRewrite(ExtOpType extOp,
2054 PatternRewriter &rewriter) const override {
2055 // The source must be a bitcast op.
2056 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2057 if (!bitCastOp)
2058 return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2059
2060 // Set up the BitCastRewriter and verify the precondition.
2061 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2062 VectorType targetVectorType = bitCastOp.getResultVectorType();
2063 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2064 if (failed(bcr.commonPrecondition(
2065 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2066 return failure();
2067
2068 // Perform the rewrite.
2069 Value runningResult;
2070 Value sourceValue = bitCastOp.getSource();
2071 auto shuffledElementType =
2072 cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2073 for (const BitCastRewriter::Metadata &metadata :
2074 bcr.precomputeMetadata(shuffledElementType)) {
2075 runningResult = bcr.genericRewriteStep(
2076 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2077 }
2078
2079 // Finalize the rewrite.
2080 bool narrowing =
2081 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2082 shuffledElementType.getIntOrFloatBitWidth();
2083 if (narrowing) {
2084 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2085 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2086 } else {
2087 rewriter.replaceOpWithNewOp<ExtOpType>(
2088 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2089 }
2090
2091 return success();
2092 }
2093};
2094
2095/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2096/// bitwise ops that take advantage of high-level information to avoid leaving
2097/// LLVM to scramble with peephole optimizations. Templated to choose between
2098/// signed and unsigned conversions.
2099///
2100/// EXAMPLE 1 (signed):
2101/// arith.extsi %in : vector<8xi4> to vector<8xi32>
2102/// is rewritten as:
2103/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2104/// %1 = arith.shli %0, 4 : vector<4xi8>
2105/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2106/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2107/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2108/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2109///
2110/// EXAMPLE 2 (fp):
2111/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2112/// is rewritten as:
2113/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2114/// %1 = arith.shli %0, 4 : vector<4xi8>
2115/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2116/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2117/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2118/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2119///
2120/// EXAMPLE 3 (unsigned):
2121/// arith.extui %in : vector<8xi4> to vector<8xi32>
2122/// is rewritten as:
2123/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2124/// %1 = arith.andi %0, 15 : vector<4xi8>
2125/// %2 = arith.shrui %0, 4 : vector<4xi8>
2126/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2127/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2128///
2129template <typename ConversionOpType, bool isSigned>
2130struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2131 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2132
2133 LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2134 PatternRewriter &rewriter) const override {
2135 // Verify the preconditions.
2136 Value srcValue = conversionOp.getIn();
2137 VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2138 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2139
2140 if (failed(
2141 commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2142 return failure();
2143
2144 // Check general alignment preconditions.
2145 if (failed(alignedConversionPrecondition(
2146 rewriter, srcVecType,
2147 /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2148 return failure();
2149
2150 // Perform the rewrite.
2151 Location loc = conversionOp.getLoc();
2152 const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2153 : extractNBitsPerByteAndExtendToI8;
2154 Value subByteExt;
2155 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2156 case 2:
2157 subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2158 break;
2159 case 4:
2160 subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2161 break;
2162 default:
2163 return failure();
2164 }
2165
2166 // Finalize the rewrite.
2167 rewriter.replaceOpWithNewOp<ConversionOpType>(
2168 conversionOp, conversionOp.getType(), subByteExt);
2169 return success();
2170 }
2171};
2172
2173/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2174/// bitwise ops that take advantage of high-level information to avoid leaving
2175/// LLVM to scramble with peephole optimizations.
2176///
2177/// For example:
2178/// arith.trunci %in : vector<8xi32> to vector<8xi4>
2179///
2180/// is rewritten as:
2181///
/// %tr = arith.trunci %in : vector<8xi32> to vector<8xi8>
2182/// %cst = arith.constant dense<15> : vector<4xi8>
2183/// %cst_0 = arith.constant dense<4> : vector<4xi8>
2184/// %0, %1 = vector.deinterleave %tr : vector<8xi8> -> vector<4xi8>
2185/// %2 = arith.andi %0, %cst : vector<4xi8>
2186/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2187/// %4 = arith.ori %2, %3 : vector<4xi8>
2188/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2189///
2190struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2191 using Base::Base;
2192
2193 LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2194 PatternRewriter &rewriter) const override {
2195 // Verify the preconditions.
2196 Value srcValue = truncOp.getIn();
2197 auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2198 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2199 if (!srcVecType || !dstVecType)
2200 return failure();
2201
2202 if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2203 return failure();
2204
2205 // TODO: Add support for truncating to i2.
2206 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2207 return failure();
2208
2209 // Check general alignment preconditions. We invert the src/dst type order
2210 // to reuse the existing precondition logic.
2211 if (failed(alignedConversionPrecondition(
2212 rewriter, dstVecType,
2213 /*containerTy=*/rewriter.getI8Type(), truncOp)))
2214 return failure();
2215
2216 // Create a new iX -> i8 truncation op.
2217 Location loc = truncOp.getLoc();
2218 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2219 Value i8TruncVal =
2220 arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2221
2222 // Rewrite the i8 -> i4 truncation part.
2223 Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2224
2225 // Finalize the rewrite.
2226 rewriter.replaceOp(truncOp, subByteTrunc);
2227 return success();
2228 }
2229};
2230
2231/// Rewrite a sub-byte vector transpose into a sequence of instructions that
2232/// perform the transpose on wider (byte) element types.
2233///
2234/// EXAMPLE:
2235/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2236///
2237/// is rewritten as:
2238///
2239/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2240/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2241/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2242///
2243struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2244 using Base::Base;
2245
2246 RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2247 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2248
2249 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2250 PatternRewriter &rewriter) const override {
2251 // Precondition: sub-byte integer transpose.
2252 constexpr unsigned minNativeBitwidth = 8;
2253 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2254 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2255 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2256 return rewriter.notifyMatchFailure(transposeOp,
2257 "not a sub-byte transpose");
2258 }
2259
2260 // Perform the rewrite.
2261 Location loc = transposeOp.getLoc();
2262 // Signed/unsigned interpretation shouldn't matter here as we are just
2263 // transposing the elements and truncating them back to the original size.
2264 // TODO: Use unsigned extension (more efficient) when emulation or backend
2265 // support is available.
2266 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2267 std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2268 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2269 transposeOp.getVector());
2270 Value newTranspose = vector::TransposeOp::create(
2271 rewriter, loc, extOp, transposeOp.getPermutation());
2272 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2273 rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2274 newTranspose);
2275 return success();
2276 }
2277};
2278
2279} // namespace
2280
2281//===----------------------------------------------------------------------===//
2282// Public Interface Definition
2283//===----------------------------------------------------------------------===//
2284
2285// The emulated type is inferred from the converted memref type.
2286void vector::populateVectorNarrowTypeEmulationPatterns(
2287 const arith::NarrowTypeEmulationConverter &typeConverter,
2288 RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) {
2289 // Populate `vector.*` conversion patterns.
2290 // TODO: #119553 support atomicity
2291 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2292 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2293 typeConverter, patterns.getContext());
2294
2295 // Populate `vector.*` store conversion patterns. The caller can choose
2296 // to avoid emitting atomic operations and reduce them to a read-modify-write
2297 // sequence for stores if it is known that there is no thread contention.
2298 patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW,
2299 assumeAligned);
2300}
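// A minimal usage sketch (assumed driver code, not part of this file). A pass
// would typically combine these patterns with the matching memref/arith
// narrow-type emulation patterns and apply them via dialect conversion:
//
//   arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(
//       converter, patterns, /*disableAtomicRMW=*/false,
//       /*assumeAligned=*/false);
//   // ... plus memref/arith emulation patterns and a ConversionTarget ...
//   (void)applyPartialConversion(op, target, std::move(patterns));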
2301
2302void vector::populateVectorNarrowTypeRewritePatterns(
2303 RewritePatternSet &patterns, PatternBenefit benefit) {
2304 // TODO: Document what the emulated type is.
2305 patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2306 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2307 benefit);
2308
2309 // Patterns for aligned cases. We set higher priority as they are expected to
2310 // generate better performance for aligned cases.
2311 // The container type is always i8.
2312 patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2313 RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2314 RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2315 benefit.getBenefit() + 1);
2316 // The container type is always i8.
2317 patterns
2318 .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2319 RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2320 patterns.getContext(), benefit.getBenefit() + 1);
2321}
2322
2323// The container type is always i8.
2324void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2325 RewritePatternSet &patterns, PatternBenefit benefit) {
2326 patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2327}
2328
2329void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2330 arith::NarrowTypeEmulationConverter &typeConverter,
2331 RewritePatternSet &patterns) {
2332 memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
2333 vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
2334}