VectorEmulateNarrowType.cpp
1//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to emulate
10// narrow types that are not supported by the target hardware, e.g. i4
11// ("emulated type"), using wider types, e.g. i8 ("container type").
12//
13/// Currently, only power-of-two integer types are supported. These are
14/// converted to wider integers that are 8 bits wide or wider.
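///
/// For example, an (aligned) load of eight `i4` elements is emulated with a
/// load of four `i8` elements followed by a bitcast (illustrative):
///
///   %1 = vector.load %0[%lin_idx] : memref<16xi8>, vector<4xi8>
///   %2 = vector.bitcast %1 : vector<4xi8> to vector<8xi4>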
15///
16/// TODO: Support for non-powers-of-two.
17//===----------------------------------------------------------------------===//
18
32#include "mlir/IR/Value.h"
33#include "mlir/Transforms/DialectConversion.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Support/raw_ostream.h"
38#include <cstdint>
39#include <optional>
40
42
43using namespace mlir;
44
45#define DEBUG_TYPE "vector-narrow-type-emulation"
46
47using VectorValue = TypedValue<VectorType>;
48using MemRefValue = TypedValue<MemRefType>;
49
50//===----------------------------------------------------------------------===//
51// Utils
52//===----------------------------------------------------------------------===//
53
54/// Returns a compressed mask for the emulated vector. For example, when
55/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
56/// elements span two dest elements), this method compresses `vector<8xi1>`
57/// into `vector<2xi1>`.
58///
59/// The compressed/output mask value is set iff any mask in the corresponding
60/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
61/// `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
62/// following mask:
63///
64/// %mask = [1, 1, 0, 0, 0, 0]
65///
66/// will first be padded at the front with `numFrontPadElems` zeros, and then at
67/// the back with zeros to make the number of elements a multiple of
68/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
69///
70/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
71///
72/// then it will return the following new compressed mask:
73///
74/// %mask = [1, 1, 0, 0]
75///
76/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
77/// `numSrcElemsPerDest`.
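///
/// EXAMPLE (illustrative). For a `vector.create_mask` producer with
/// `numSrcElems = 8`, `numSrcElemsPerDest = 4` and `numFrontPadElems = 0`:
///
///   %mask = vector.create_mask %c5 : vector<8xi1>
///
/// is compressed to:
///
///   %new_mask = vector.create_mask %c2 : vector<2xi1>
///
/// since the new mask has ceilDiv(8, 4) == 2 elements and the new bound is
/// ceilDiv(5 + 0, 4) == 2.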
78static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
79 Location loc, Value mask,
80 int numSrcElems,
81 int numSrcElemsPerDest,
82 int numFrontPadElems = 0) {
83
84 assert(numFrontPadElems < numSrcElemsPerDest &&
85 "numFrontPadElems must be less than numSrcElemsPerDest");
86
87 auto numDestElems =
88 (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
89 numSrcElemsPerDest;
90
91 Operation *maskOp = mask.getDefiningOp();
92 SmallVector<vector::ExtractOp, 2> extractOps;
93 // TODO: Add support for `vector.broadcast`.
94 // Finding the mask creation operation.
95 while (maskOp &&
96 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
97 maskOp)) {
98 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99 maskOp = extractOp.getSource().getDefiningOp();
100 extractOps.push_back(extractOp);
101 }
102 }
103
104 if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
105 maskOp))
106 return failure();
107
108 // Computing the "compressed" mask. All the emulation logic (i.e. computing
109 // new mask index) only happens on the last dimension of the vectors.
110 SmallVector<int64_t> maskShape(
111 cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
112 maskShape.back() = numDestElems;
113 auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
114 std::optional<Operation *> newMask =
115 TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
116 .Case(
117 [&](vector::CreateMaskOp createMaskOp)
118 -> std::optional<Operation *> {
119 OperandRange maskOperands = createMaskOp.getOperands();
120 // The `vector.create_mask` op creates a mask arrangement
121 // without any zeros at the front. Also, because
122 // `numFrontPadElems` is strictly smaller than
123 // `numSrcElemsPerDest`, the compressed mask generated by
124 // padding the original mask by `numFrontPadElems` will not
125 // have any zeros at the front either.
126 AffineExpr s0;
127 bindSymbols(rewriter.getContext(), s0);
128 s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
129 OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
130 OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
131 rewriter, loc, s0, origIndex);
132 SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133 newMaskOperands.push_back(
134 getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135 return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
136 newMaskOperands);
137 })
138 .Case([&](vector::ConstantMaskOp constantMaskOp)
139 -> std::optional<Operation *> {
140 // Take the shape of mask, compress its trailing dimension:
141 SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
142 int64_t &maskIndex = maskDimSizes.back();
143 maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
144 numSrcElemsPerDest);
145 return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
146 maskDimSizes);
147 })
148 .Case([&](arith::ConstantOp constantOp)
149 -> std::optional<Operation *> {
150 // TODO: Support multiple dimensions.
151 if (maskShape.size() != 1)
152 return std::nullopt;
153 // Rearrange the original mask values to cover the whole potential
154 // loading region. For example, in the case of using byte-size for
155 // emulation, given the following mask:
156 //
157 // %mask = [0, 1, 0, 1, 0, 0]
158 //
159 // With a front offset of 1, the mask will be padded with 0s at the front
160 // and back so that:
161 // 1. It is aligned with the effective loading bits
162 // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
163 // coverage size is a multiple of bytes). The new mask will be like
164 // this before compressing:
165 //
166 // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
167 auto originalMask =
168 cast<DenseIntElementsAttr>(constantOp.getValue());
169 SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
170 paddedMaskValues.append(originalMask.template value_begin<bool>(),
171 originalMask.template value_end<bool>());
172 paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
173
174 // Compressing by combining every `numSrcElemsPerDest` elements:
175 SmallVector<bool> compressedMaskValues;
176 for (size_t i = 0; i < paddedMaskValues.size();
177 i += numSrcElemsPerDest) {
178 bool combinedValue = false;
179 for (int j = 0; j < numSrcElemsPerDest; ++j) {
180 combinedValue |= paddedMaskValues[i + j];
181 }
182 compressedMaskValues.push_back(combinedValue);
183 }
184 return arith::ConstantOp::create(
185 rewriter, loc,
186 DenseElementsAttr::get(newMaskType, compressedMaskValues));
187 });
188
189 if (!newMask)
190 return failure();
191
192 while (!extractOps.empty()) {
193 newMask =
194 vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
195 extractOps.back().getMixedPosition());
196 extractOps.pop_back();
197 }
198
199 return *newMask;
200}
201
202/// Extracts 1-D subvector from a 1-D vector.
203///
204/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
205/// from `src`, starting at `offset`. The result is also a rank-1 vector:
206///
207/// vector<numElemsToExtract x !elType>
208///
209/// (`!elType` is the element type of the source vector). As `offset` is a known
210/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
211///
212/// EXAMPLE:
213/// %res = vector.extract_strided_slice %src
214/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
215static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
216 Value src, int64_t offset,
217 int64_t numElemsToExtract) {
218 auto vectorType = cast<VectorType>(src.getType());
219 assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
220 assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
221 "subvector out of bounds");
222
223 // When extracting all available elements, just use the source vector as the
224 // result.
225 if (vectorType.getNumElements() == numElemsToExtract)
226 return src;
227
228 auto offsets = rewriter.getI64ArrayAttr({offset});
229 auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
230 auto strides = rewriter.getI64ArrayAttr({1});
231
232 auto resultVectorType =
233 VectorType::get({numElemsToExtract}, vectorType.getElementType());
234 return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
235 src, offsets, sizes, strides)
236 ->getResult(0);
237}
238
239/// Inserts 1-D subvector into a 1-D vector.
240///
241/// Inserts the input rank-1 source vector into the destination vector starting
242/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
243/// `vector.insert_strided_slice`.
244///
245/// EXAMPLE:
246/// %res = vector.insert_strided_slice %src, %dest
247/// {offsets = [%offset], strides [1]}
248static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
249 Value src, Value dest, int64_t offset) {
250 [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
251 [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
252 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253 "expected source and dest to be rank-1 vector types");
254
255 // If overwriting the whole destination vector, just return the source.
256 if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257 return src;
258
259 auto offsets = rewriter.getI64ArrayAttr({offset});
260 auto strides = rewriter.getI64ArrayAttr({1});
261 return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
262 dest, offsets, strides);
263}
264
265/// Extracts 1-D subvector from a 1-D vector.
266///
267/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268/// from `src`, starting at `offset`. The result is also a rank-1 vector:
269///
270/// vector<numElemsToExtract x !elType>
271///
272/// (`!elType` is the element type of the source vector). As `offset` is assumed
273/// to be a _dynamic_ SSA value, this helper method generates a sequence of
274/// `vector.extract` + `vector.insert` pairs.
275///
276/// EXAMPLE:
277/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279/// %c1 = arith.constant 1 : index
280/// %idx2 = arith.addi %offset, %c1 : index
281/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283/// (...)
284static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
285 Value src, Value dest,
286 OpFoldResult offset,
287 int64_t numElemsToExtract) {
288 auto srcVecTy = cast<VectorType>(src.getType());
289 assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
290 // NOTE: We are unable to take the offset into account in the following
291 // assert, hence it's still possible that the subvector is out-of-bounds even
292 // if the condition is true.
293 assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294 "subvector out of bounds");
295
296 // When extracting all available elements, just use the source vector as the
297 // result.
298 if (srcVecTy.getNumElements() == numElemsToExtract)
299 return src;
300
301 for (int i = 0; i < numElemsToExtract; ++i) {
302 Value extractLoc =
303 (i == 0) ? dyn_cast<Value>(offset)
304 : arith::AddIOp::create(
305 rewriter, loc, rewriter.getIndexType(),
306 dyn_cast<Value>(offset),
307 arith::ConstantIndexOp::create(rewriter, loc, i));
308 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
309 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
310 }
311 return dest;
312}
313
314/// Inserts 1-D subvector into a 1-D vector.
315///
316/// Inserts the input rank-1 source vector into the destination vector starting
317/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
318/// uses a sequence of `vector.extract` + `vector.insert` pairs.
319///
320/// EXAMPLE:
321/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
322/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
323/// %c1 = arith.constant 1 : index
324/// %idx2 = arith.addi %offset, %c1 : index
325/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
326/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
327/// (...)
328static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
329 Value src, Value dest,
330 OpFoldResult offset,
331 int64_t numElemsToInsert) {
332 auto srcVecTy = cast<VectorType>(src.getType());
333 auto destVecTy = cast<VectorType>(dest.getType());
334 assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
335 "expected source and dest to be rank-1 vector types");
336 (void)srcVecTy;
337 (void)destVecTy;
338 assert(numElemsToInsert > 0 &&
339 "the number of elements to insert must be greater than 0");
340 // NOTE: We are unable to take the offset into account in the following
341 // assert, hence it's still possible that the subvector is out-of-bounds even
342 // if the condition is true.
343 assert(numElemsToInsert <= destVecTy.getNumElements() &&
344 "subvector out of bounds");
345
346 Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
347 for (int64_t i = 0; i < numElemsToInsert; ++i) {
348 auto insertLoc =
349 i == 0 ? destOffsetVal
350 : arith::AddIOp::create(
351 rewriter, loc, rewriter.getIndexType(), destOffsetVal,
352 arith::ConstantIndexOp::create(rewriter, loc, i));
353 auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
354 dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
355 }
356 return dest;
357}
358
359/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`.
360///
361/// Specifically, use `containerElemTy` for loading a vector of
362/// `emulatedElemTy`. The load location is given by `base` and
363/// `linearizedIndices`, and the load size is given by
364/// `numContainerElemsToLoad`.
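///
/// EXAMPLE (illustrative, for i2 emulated with i8 and
/// `numContainerElemsToLoad` == 2):
///
///   %load = vector.load %base[%idx] : memref<?xi8>, vector<2xi8>
///   %res = vector.bitcast %load : vector<2xi8> to vector<8xi2>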
365static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
366 Value base,
367 OpFoldResult linearizedIndices,
368 int64_t numContainerElemsToLoad,
369 Type emulatedElemTy,
370 Type containerElemTy) {
371 auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
372 emulatedElemTy.getIntOrFloatBitWidth();
373 auto newLoad = vector::LoadOp::create(
374 rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
375 base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
376 return vector::BitCastOp::create(
377 rewriter, loc,
378 VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
379 emulatedElemTy),
380 newLoad);
381}
382
383/// Downcast two values to `downcastType`, then select values
384/// based on `mask`, and cast the result to `upcastType`.
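///
/// EXAMPLE (illustrative, with `downcastType` == vector<4xi2> and
/// `upcastType` == vector<1xi8>, as used by the RMW helpers below):
///
///   %f = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>
///   %sel = arith.select %mask, %trueValue, %f : vector<4xi1>, vector<4xi2>
///   %res = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///
/// (the initial bitcasts are emitted only for operands whose type differs from
/// `downcastType`).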
385static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
386 VectorType downcastType,
387 VectorType upcastType, Value mask,
388 Value trueValue, Value falseValue) {
389 assert(
390 downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
391 upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
392 "expected input and output number of bits to match");
393 if (trueValue.getType() != downcastType) {
394 trueValue =
395 vector::BitCastOp::create(builder, loc, downcastType, trueValue);
396 }
397 if (falseValue.getType() != downcastType) {
398 falseValue =
399 vector::BitCastOp::create(builder, loc, downcastType, falseValue);
400 }
401 Value selectedType =
402 arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
403 // Upcast the selected value to the new type.
404 return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
405}
406
407/// Emits a `memref.generic_atomic_rmw` op to store a subbyte-sized value into
408/// a byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector
409/// of subbyte-sized elements whose total size is 8 bits, and the mask selects
410/// which elements to store.
411///
412/// Inputs:
413/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
414/// storeIdx = 2
415/// valueToStore = |3|3|3|3| : vector<4xi2>
416/// mask = |0|0|1|1| : vector<4xi1>
417///
418/// Result:
419/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
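///
/// The emitted IR is roughly (illustrative):
///
///   %res = memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<1xi8> {
///   ^bb0(%current : i8):
///     %vec = vector.from_elements %current : vector<1xi8>
///     // ... mask-select `valueToStore` over %vec via `downcastSelectAndUpcast`,
///     // then extract the single resulting byte ...
///     memref.atomic_yield %updated : i8
///   }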
420static void atomicRMW(OpBuilder &builder, Location loc,
421 MemRefValue linearizedMemref, Value storeIdx,
422 VectorValue valueToStore, Value mask) {
423 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
424
425 // Create an atomic load-modify-write region using
426 // `memref.generic_atomic_rmw`.
427 auto atomicOp = memref::GenericAtomicRMWOp::create(
428 builder, loc, linearizedMemref, ValueRange{storeIdx});
429 Value origValue = atomicOp.getCurrentValue();
430
431 OpBuilder::InsertionGuard guard(builder);
432 builder.setInsertionPointToStart(atomicOp.getBody());
433
434 // Wrap the original value (loaded by the atomic RMW op) into a one-element
435 // vector.
436 auto oneElemVecType = VectorType::get({1}, origValue.getType());
437 Value origVecValue = vector::FromElementsOp::create(
438 builder, loc, oneElemVecType, ValueRange{origValue});
439
440 // Construct the final masked value and yield it.
441 Value maskedValue =
442 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
443 oneElemVecType, mask, valueToStore, origVecValue);
444 auto scalarMaskedValue =
445 vector::ExtractOp::create(builder, loc, maskedValue, 0);
446 memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
447}
448
449/// Generate a non-atomic read-modify-write sequence for storing to the emulated
450/// type. It has similar logic to `atomicRMW`, but without atomicity.
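///
/// The emitted IR is roughly (illustrative, for a vector<4xi2> value):
///
///   %old = vector.load %linearizedMemref[%idx] : memref<?xi8>, vector<1xi8>
///   %old_sub = vector.bitcast %old : vector<1xi8> to vector<4xi2>
///   %sel = arith.select %mask, %valueToStore, %old_sub
///       : vector<4xi1>, vector<4xi2>
///   %new = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///   vector.store %new, %linearizedMemref[%idx] : memref<?xi8>, vector<1xi8>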
451static void nonAtomicRMW(OpBuilder &builder, Location loc,
452 MemRefValue linearizedMemref, Value linearizedIndex,
453 VectorValue valueToStore, Value mask) {
454 assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
455
456 auto oneElemVecType =
457 VectorType::get({1}, linearizedMemref.getType().getElementType());
458 Value origVecValue =
459 vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
460 ValueRange{linearizedIndex});
461 origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
462 origVecValue);
463
464 Value maskedValue =
465 downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
466 oneElemVecType, mask, valueToStore, origVecValue);
467 vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
468 linearizedIndex);
469}
470
471/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
472/// and insert it into an empty vector at `insertOffset`.
473/// Inputs:
474/// vec_in = |0|1|2|3| : vector<4xi2>
475/// extractOffset = 1
476/// sliceNumElements = 2
477/// insertOffset = 2
478/// Output:
479/// vec_out = |0|0|1|2| : vector<4xi2>
480static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
481 Location loc, VectorValue vector,
482 int64_t extractOffset,
483 int64_t sliceNumElements,
484 int64_t insertOffset) {
485 assert(vector.getType().getRank() == 1 && "expected 1-D vector");
486 auto vectorElementType = vector.getType().getElementType();
487 // TODO: Update and use `alignedConversionPrecondition` in place of
488 // these asserts.
489 assert(
490 sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
491 "sliceNumElements * vector element size must be less than or equal to 8");
492 assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
493 "vector element must be a valid sub-byte type");
494 auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
495 auto emptyByteVector = arith::ConstantOp::create(
496 rewriter, loc,
497 VectorType::get({emulatedPerContainerElem}, vectorElementType),
498 rewriter.getZeroAttr(
499 VectorType::get({emulatedPerContainerElem}, vectorElementType)));
500 auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
501 extractOffset, sliceNumElements);
502 return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
503 insertOffset);
504}
505
506namespace {
507
508//===----------------------------------------------------------------------===//
509// ConvertVectorStore
510//===----------------------------------------------------------------------===//
511
512// Emulate `vector.store` using a multi-byte container type.
513//
514// The container type is obtained through the op adaptor and would normally be
515// generated via `NarrowTypeEmulationConverter`.
516//
517// EXAMPLE 1
518// (aligned store of i4, emulated using i8 as the container type)
519//
520// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
521//
522// is rewritten as:
523//
524// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
525// vector.store %src_bitcast, %dest_bitcast[%idx]
526// : memref<16xi8>, vector<4xi8>
527//
528// EXAMPLE 2
529// (unaligned store of i2, emulated using i8 as the container type)
530//
531// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
532//
533// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
534// is modelled using 3 bytes:
535//
536// Byte 0 Byte 1 Byte 2
537// +----------+----------+----------+
538// | oooooooo | ooooNNNN | NNoooooo |
539// +----------+----------+----------+
540//
541// N - (N)ew entries (i.e. to be overwritten by vector.store)
542// o - (o)ld entries (to be preserved)
543//
544// For the generated output in the non-atomic case, see:
545// * `@vector_store_i2_const_index_two_partial_stores`
546// in:
547// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
548//
549// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
550// `true` to generate non-atomic RMW sequences.
551struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
552 using Base::Base;
553
554 ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
555 : OpConversionPattern<vector::StoreOp>(context),
556 disableAtomicRMW(disableAtomicRMW) {}
557
558 LogicalResult
559 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override {
561
562 if (op.getValueToStore().getType().getRank() != 1)
563 return rewriter.notifyMatchFailure(op,
564 "only 1-D vectors are supported ATM");
565
566 auto loc = op.getLoc();
567
568 auto valueToStore = cast<VectorValue>(op.getValueToStore());
569 auto containerElemTy =
570 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
571 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
572 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
573 int containerBits = containerElemTy.getIntOrFloatBitWidth();
574
575 // Check per-element alignment.
576 if (containerBits % emulatedBits != 0) {
577 return rewriter.notifyMatchFailure(
578 op, "impossible to pack emulated elements into container elements "
579 "(bit-wise misalignment)");
580 }
581 int emulatedPerContainerElem = containerBits / emulatedBits;
582
583 // Adjust the number of elements to store when emulating narrow types.
584 // Here only the 1-D vector store is considered, and the N-D memref types
585 // should be linearized.
586 // For example, to emulate i4 to i8, the following op:
587 //
588 // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
589 //
590 // can be replaced with
591 //
592 // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
593 // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
594 // vector<4xi8>
595
596 auto origElements = valueToStore.getType().getNumElements();
597 // Note, per-element-alignment was already verified above.
598 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
599 // Does the trailing dim of the destination match the source vector size?
600 // If yes, then the corresponding index must be 0.
601 // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
602 // However, that makes some tests fail, so we need to audit first.
603 auto trailingDim = op.getBase().getType().getShape().back();
604 bool trailingDimsMatch =
605 ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
606
607 auto stridedMetadata =
608 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
609
610 // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
611 // non-zero. As such, this is not needed.
612 OpFoldResult linearizedIndices;
613 memref::LinearizedMemRefInfo linearizedInfo;
614 std::tie(linearizedInfo, linearizedIndices) =
615 memref::getLinearizedMemRefOffsetAndSize(
616 rewriter, loc, emulatedBits, containerBits,
617 stridedMetadata.getConstifiedMixedOffset(),
618 stridedMetadata.getConstifiedMixedSizes(),
619 stridedMetadata.getConstifiedMixedStrides(),
620 getAsOpFoldResult(adaptor.getIndices()));
621
622 std::optional<int64_t> foldedNumFrontPadElems =
623 (isDivisibleInSize && trailingDimsMatch)
624 ? 0
625 : getConstantIntValue(linearizedInfo.intraDataOffset);
626
627 if (!foldedNumFrontPadElems) {
628 return rewriter.notifyMatchFailure(
629 op, "subbyte store emulation: dynamic front padding size is "
630 "not yet implemented");
631 }
632
633 auto memrefBase = cast<MemRefValue>(adaptor.getBase());
634
635 // RMWs are not needed when:
636 // * no _partial_ stores are required.
637 // A partial store is defined as a store in which only a part of the
638 // container element is overwritten, e.g.
639 //
640 // Dest before (8 bits)
641 // +----------+
642 // | 11000000 |
643 // +----------+
644 //
645 // Dest after storing 0xF at offset 4 (in bits)
646 // +----------+
647 // | 11001111 |
648 // +----------+
649 //
650 // At a higher level, this translates to:
651 // 1. The source vector size (in bits) is a multiple of byte size.
652 // 2. The address of the store is aligned to the container type width
653 // boundary.
654 //
655 // EXAMPLE 1:
656 // Requires partial store:
657 // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
658 //
659 // EXAMPLE 2:
660 // Does not require a partial store:
661 // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
662 //
663 // TODO: Take linearizedInfo.linearizedOffset into account. This is
664 // currently not needed/used/exercised as all our tests set offset to 0.
665 bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
666
667 if (!emulationRequiresPartialStores) {
668 // Basic case: storing full bytes.
669 auto numElements = origElements / emulatedPerContainerElem;
670 auto bitCast = vector::BitCastOp::create(
671 rewriter, loc, VectorType::get(numElements, containerElemTy),
672 op.getValueToStore());
673 rewriter.replaceOpWithNewOp<vector::StoreOp>(
674 op, bitCast.getResult(), memrefBase,
675 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
676 return success();
677 }
678
679 // Next, handle the case when sub-byte read-modify-write
680 // sequences are needed to emulate a vector store.
681 // Here is an example:
682 //
683 // Vector to store: vector<7xi2>
684 // Value to store: 11 11 11 11 11 11 11 (all ones)
685 //
686 // Destination: memref<12xi2>
687 // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
688 //
689 // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
690 //
691 // Destination memref before:
692 //
693 // Byte 0 Byte 1 Byte 2
694 // +----------+----------+----------+
695 // | 00000000 | 00000000 | 00000000 |
696 // +----------+----------+----------+
697 //
698 // Destination memref after:
699 //
700 // Byte 0 Byte 1 Byte 2
701 // +----------+----------+----------+
702 // | 00001111 | 11111111 | 11000000 |
703 // +----------+----------+----------+
704 //
705 // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
706 // need for atomicity). Stores to Byte 0 and Byte 2 are "partial" and hence
707 // require RMW access (atomicity is required).
708
709 // The index into the target memref we are storing to.
710 Value currentDestIndex =
711 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
712 // The index into the source vector we are currently processing.
713 auto currentSourceIndex = 0;
714
715 // Build a mask used for rmw.
716 auto subWidthStoreMaskType =
717 VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
718
719 auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
720
721 // 1. Partial width store for the leading byte.
722 // When the store address is not aligned to the emulated width boundary,
723 // deal with the unaligned part first so that the remaining elements are
724 // aligned to the width boundary.
725 auto frontSubWidthStoreElem =
726 (emulatedPerContainerElem - *foldedNumFrontPadElems) %
727 emulatedPerContainerElem;
728 if (frontSubWidthStoreElem > 0) {
729 SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
730 if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
731 std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
732 origElements, true);
733 frontSubWidthStoreElem = origElements;
734 } else {
735 std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
736 *foldedNumFrontPadElems, true);
737 }
738 auto frontMask = arith::ConstantOp::create(
739 rewriter, loc,
740 DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
741
742 currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
743 auto value =
744 extractSliceIntoByte(rewriter, loc, valueToStore, 0,
745 frontSubWidthStoreElem, *foldedNumFrontPadElems);
746
747 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
748 cast<VectorValue>(value), frontMask.getResult());
749 }
750
751 if (currentSourceIndex >= origElements) {
752 rewriter.eraseOp(op);
753 return success();
754 }
755
756 // Increment the destination index by 1 to align to the emulated width
757 // boundary.
758 auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
759 currentDestIndex = arith::AddIOp::create(
760 rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
761
762 // 2. Full width store for the inner output bytes.
763 // After the previous step, the store address is aligned to the emulated
764 // width boundary.
765 int64_t fullWidthStoreSize =
766 (origElements - currentSourceIndex) / emulatedPerContainerElem;
767 int64_t numNonFullWidthElements =
768 fullWidthStoreSize * emulatedPerContainerElem;
769 if (fullWidthStoreSize > 0) {
770 auto fullWidthStorePart = staticallyExtractSubvector(
771 rewriter, loc, valueToStore, currentSourceIndex,
772 numNonFullWidthElements);
773
774 auto originType = cast<VectorType>(fullWidthStorePart.getType());
775 auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
776 auto storeType = VectorType::get(
777 {originType.getNumElements() / emulatedPerContainerElem},
778 memrefElemType);
779 auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
780 fullWidthStorePart);
781 vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
782 currentDestIndex);
783
784 currentSourceIndex += numNonFullWidthElements;
785 currentDestIndex = arith::AddIOp::create(
786 rewriter, loc, rewriter.getIndexType(), currentDestIndex,
787 arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
788 }
789
790 // 3. Partial width store for the trailing output byte.
791 // It is needed when the residual length is smaller than the emulated width,
792 // which is not covered in step 2 above.
793 auto remainingElements = origElements - currentSourceIndex;
794 if (remainingElements != 0) {
795 auto subWidthStorePart =
796 extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
797 currentSourceIndex, remainingElements, 0);
798
799 // Generate back mask.
800 auto maskValues = SmallVector<bool>(emulatedPerContainerElem, false);
801 std::fill_n(maskValues.begin(), remainingElements, 1);
802 auto backMask = arith::ConstantOp::create(
803 rewriter, loc,
804 DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
805
806 storeFunc(rewriter, loc, memrefBase, currentDestIndex,
807 cast<VectorValue>(subWidthStorePart), backMask.getResult());
808 }
809
810 rewriter.eraseOp(op);
811 return success();
812 }
813
814private:
815 const bool disableAtomicRMW;
816};
817
818//===----------------------------------------------------------------------===//
819// ConvertVectorMaskedStore
820//===----------------------------------------------------------------------===//
821
822/// Converts `vector.maskedstore` operations on narrow element types to work
823/// with wider, byte-aligned container types by adjusting the mask and using
824/// bitcasting.
825///
826/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
827/// (each `i8` container element holds two `i4` values) and storing with an
828/// adjusted mask.
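///
/// A sketch of the rewrite for the example above (illustrative; `%lin_idx` and
/// `%new_mask` denote the linearized index and the compressed mask):
///
///   %old = vector.maskedload %base[%lin_idx], %new_mask, %zeros
///       : memref<6xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %old_i4 = vector.bitcast %old : vector<3xi8> to vector<6xi4>
///   %sel = arith.select %mask, %val_to_store, %old_i4
///       : vector<6xi1>, vector<6xi4>
///   %new = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
///   vector.maskedstore %base[%lin_idx], %new_mask, %new
///       : memref<6xi8>, vector<3xi1>, vector<3xi8>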
829struct ConvertVectorMaskedStore final
830 : OpConversionPattern<vector::MaskedStoreOp> {
831 using Base::Base;
832
833 LogicalResult
834 matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
835 ConversionPatternRewriter &rewriter) const override {
836
837 // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
838 if (op.getValueToStore().getType().getRank() != 1)
839 return rewriter.notifyMatchFailure(
840 op, "Memref in vector.maskedstore op must be flattened beforehand.");
841
842 auto loc = op.getLoc();
843 auto containerElemTy =
844 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
845 Type emulatedElemTy = op.getValueToStore().getType().getElementType();
846 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
847 int containerBits = containerElemTy.getIntOrFloatBitWidth();
848
849 // Check per-element alignment.
850 if (containerBits % emulatedBits != 0) {
851 return rewriter.notifyMatchFailure(
852 op, "impossible to pack emulated elements into container elements "
853 "(bit-wise misalignment)");
854 }
855
856 int emulatedPerContainerElem = containerBits / emulatedBits;
857 int origElements = op.getValueToStore().getType().getNumElements();
858 if (origElements % emulatedPerContainerElem != 0)
859 return failure();
860
861 auto stridedMetadata =
862 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
863 OpFoldResult linearizedIndicesOfr;
864 memref::LinearizedMemRefInfo linearizedInfo;
865 std::tie(linearizedInfo, linearizedIndicesOfr) =
866 memref::getLinearizedMemRefOffsetAndSize(
867 rewriter, loc, emulatedBits, containerBits,
868 stridedMetadata.getConstifiedMixedOffset(),
869 stridedMetadata.getConstifiedMixedSizes(),
870 stridedMetadata.getConstifiedMixedStrides(),
871 getAsOpFoldResult(adaptor.getIndices()));
872 Value linearizedIndices =
873 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
874
875 // Load the whole data and use arith.select to handle the corner cases.
876 //
877 // As an example, for this masked store of i4 values:
878 //
879 // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
880 //
881 // and given these input values:
882 //
883 // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
884 // %0[%c0, %c0] =
885 // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
886 // %val_to_store =
887 // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
888 //
889 // we'll have the following i4 output:
890 //
891 // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
892 //
893 // Emulating the above using i8 will give:
894 //
895 // %compressed_mask = [1, 1, 1, 0] (4 * i1)
896 // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
897 // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
898 // %select_using_shifted_mask =
899 // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
900 // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
901 //
902 // Using the compressed mask to store %packed_data results in expected
903 // output.
904 //
905 // FIXME: Make an example based on the comment above work (see #115460 for
906 // reproducer).
907 FailureOr<Operation *> newMask = getCompressedMaskOp(
908 rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
909 if (failed(newMask))
910 return failure();
911
912 auto numElements = (origElements + emulatedPerContainerElem - 1) /
913 emulatedPerContainerElem;
914 auto newType = VectorType::get(numElements, containerElemTy);
915 auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
916 rewriter.getZeroAttr(newType));
917
918 auto newLoad = vector::MaskedLoadOp::create(
919 rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
920 newMask.value()->getResult(0), passThru);
921
922 auto newBitCastType =
923 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
924 Value valueToStore =
925 vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
926 valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
927 op.getValueToStore(), valueToStore);
928 valueToStore =
929 vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
930
931 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
932 op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
933 valueToStore);
934 return success();
935 }
936};
937
938//===----------------------------------------------------------------------===//
939// ConvertVectorLoad
940//===----------------------------------------------------------------------===//
941
942/// Converts `vector.load` on narrow element types to work with
943/// wider, byte-aligned container types by adjusting load sizes and using
944/// bitcasting.
945///
946/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
947/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
948/// container holds two `i4` values) and bitcasting back.
949///
950/// There are cases where the number of elements to load is not byte-aligned. In
951/// those cases, loads are converted to byte-aligned, byte-sized loads and the
952/// target vector is extracted from the loaded vector.
953struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
954 using Base::Base;
955
956 LogicalResult
957 matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter) const override {
959 // Prerequisite: memref in the vector.load op is flattened into 1-D.
960 if (op.getVectorType().getRank() != 1)
961 return rewriter.notifyMatchFailure(
962 op, "Memref in emulated vector ops must be flattened beforehand.");
963
964 auto loc = op.getLoc();
965 auto containerElemTy =
966 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
967 Type emulatedElemTy = op.getType().getElementType();
968 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
969 int containerBits = containerElemTy.getIntOrFloatBitWidth();
970
971 // Check per-element alignment.
972 if (containerBits % emulatedBits != 0) {
973 return rewriter.notifyMatchFailure(
974 op, "impossible to pack emulated elements into container elements "
975 "(bit-wise misalignment)");
976 }
977 int emulatedPerContainerElem = containerBits / emulatedBits;
978
979 // Adjust the number of elements to load when emulating narrow types,
980 // and then cast back to the original type with vector.bitcast op.
981 // For example, to emulate i4 to i8, the following op:
982 //
983 // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
984 //
985 // can be replaced with
986 //
987 // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
988 // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
989 //
990 // There are cases where the number of elements to load is not byte-aligned,
991 // for example:
992 //
993 // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
994 //
995 // we will have to load extra bytes and extract the exact slice in between.
996 //
997 // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
998 // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
999 // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
1000 // = [1]}
1001 // : vector<8xi2> to vector<3xi2>
1002 //
1003 // TODO: Currently the extract_strided_slice's attributes must be known at
1004 // compile time as they must be constants.
1005
1006 auto origElements = op.getVectorType().getNumElements();
1007 // Note, per-element-alignment was already verified above.
1008 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1009
1010 auto stridedMetadata =
1011 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1012
1013 OpFoldResult linearizedIndices;
1014 memref::LinearizedMemRefInfo linearizedInfo;
1015 std::tie(linearizedInfo, linearizedIndices) =
1016 memref::getLinearizedMemRefOffsetAndSize(
1017 rewriter, loc, emulatedBits, containerBits,
1018 stridedMetadata.getConstifiedMixedOffset(),
1019 stridedMetadata.getConstifiedMixedSizes(),
1020 stridedMetadata.getConstifiedMixedStrides(),
1021 getAsOpFoldResult(adaptor.getIndices()));
1022
1023 std::optional<int64_t> foldedIntraVectorOffset =
1024 isDivisibleInSize ? 0
1025 : getConstantIntValue(linearizedInfo.intraDataOffset);
1026
1027 // Always load enough elements to cover the original elements.
1028 int64_t maxintraDataOffset =
1029 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1030 auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1031 emulatedPerContainerElem);
1032 Value result =
1033 emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1034 numElements, emulatedElemTy, containerElemTy);
1035
1036 if (!foldedIntraVectorOffset) {
1037 auto resultVector = arith::ConstantOp::create(
1038 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1039 result = dynamicallyExtractSubVector(
1040 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1041 linearizedInfo.intraDataOffset, origElements);
1042 } else if (!isDivisibleInSize) {
1043 result = staticallyExtractSubvector(
1044 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1045 }
1046 rewriter.replaceOp(op, result);
1047 return success();
1048 }
1049};
1050
1051//===----------------------------------------------------------------------===//
1052// ConvertVectorMaskedLoad
1053//===----------------------------------------------------------------------===//
1054
1055/// Converts `vector.maskedload` operations on narrow element types to work with
1056/// wider, byte-aligned container types by adjusting the mask and using
1057/// bitcasting.
1058///
1059/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1060/// bitcasting, since each `i8` container element holds two `i4` values.
1061struct ConvertVectorMaskedLoad final
1062 : OpConversionPattern<vector::MaskedLoadOp> {
1063 using Base::Base;
1064
1065 LogicalResult
1066 matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1067 ConversionPatternRewriter &rewriter) const override {
1068 if (op.getVectorType().getRank() != 1)
1069 return rewriter.notifyMatchFailure(
1070 op, "Memref in emulated vector ops must be flattened beforehand.");
1071
1072 auto loc = op.getLoc();
1073
1074 auto containerElemTy =
1075 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1076 Type emulatedElemTy = op.getType().getElementType();
1077 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1078 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1079
1080 // Check per-element alignment.
1081 if (containerBits % emulatedBits != 0) {
1082 return rewriter.notifyMatchFailure(
1083 op, "impossible to pack emulated elements into container elements "
1084 "(bit-wise misalignment)");
1085 }
1086 int emulatedPerContainerElem = containerBits / emulatedBits;
1087
1088 // Adjust the number of elements to load when emulating narrow types,
1089 // and then cast back to the original type with vector.bitcast op.
1090 // For example, to emulate i4 to i8, the following op:
1091 //
1092 // %mask = vector.constant_mask [3] : vector<6xi1>
1093 // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1094 // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1095 //
1096 // can be replaced with
1097 //
1098 // %new_mask = vector.constant_mask [2] : vector<3xi1>
1099 // %new_pass_thru = vector.bitcast %pass_thru :
1100 // vector<6xi4> to vector<3xi8>
1101 // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1102 // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1103 // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1104 //
1105 // Since we are effectively loading 16 bits (2xi8) from the memref with the
1106 // new mask, while originally we only wanted to effectively load 12 bits
1107 // (3xi4) from the memref, we need to set the second half of the last i8
1108 // that was effectively loaded (i.e. the second i8) to %pass_thru.
1109 //
1110 // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1111 //
1112 // Given these input values:
1113 // %mask = [1, 1, 1, 0, 0, 0]
1114 // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1115 // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1116 //
1117 // we'll have:
1118 //
1119 // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1120 //
1121 // %new_mask = [1, 1, 0]
1122 // %new_pass_thru = [0x78, 0x9A, 0xBC]
1123 // %1 = [0x12, 0x34, 0xBC]
1124 // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1125 // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1126 //
1127 // TODO: Currently, only loading an even number of elements is supported.
1128 // To deal with an odd number of elements, one has to extract the
1129 // subvector at the proper offset after bit-casting.
1130 auto origType = op.getVectorType();
1131 auto origElements = origType.getNumElements();
1132 // Note, per-element-alignment was already verified above.
1133 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1134
1135 auto stridedMetadata =
1136 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1137 OpFoldResult linearizedIndices;
1138 memref::LinearizedMemRefInfo linearizedInfo;
1139 std::tie(linearizedInfo, linearizedIndices) =
1140 memref::getLinearizedMemRefOffsetAndSize(
1141 rewriter, loc, emulatedBits, containerBits,
1142 stridedMetadata.getConstifiedMixedOffset(),
1143 stridedMetadata.getConstifiedMixedSizes(),
1144 stridedMetadata.getConstifiedMixedStrides(),
1145 getAsOpFoldResult(adaptor.getIndices()));
1146
1147 std::optional<int64_t> foldedIntraVectorOffset =
1148 isDivisibleInSize ? 0
1149 : getConstantIntValue(linearizedInfo.intraDataOffset);
1150
1151 int64_t maxIntraDataOffset =
1152 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1153 FailureOr<Operation *> newMask =
1154 getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1155 emulatedPerContainerElem, maxIntraDataOffset);
1156 if (failed(newMask))
1157 return failure();
1158
1159 Value passthru = op.getPassThru();
1160
1161 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1162 emulatedPerContainerElem);
1163 auto loadType = VectorType::get(numElements, containerElemTy);
1164 auto newBitcastType =
1165 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1166
1167 auto emptyVector = arith::ConstantOp::create(
1168 rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1169 if (!foldedIntraVectorOffset) {
1170 passthru = dynamicallyInsertSubVector(
1171 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1172 origElements);
1173 } else if (!isDivisibleInSize) {
1174 passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1175 *foldedIntraVectorOffset);
1176 }
1177 auto newPassThru =
1178 vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1179
1180 // Generating the new masked load.
1181 auto newLoad = vector::MaskedLoadOp::create(
1182 rewriter, loc, loadType, adaptor.getBase(),
1183 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1184 newMask.value()->getResult(0), newPassThru);
1185
1186 // Setting the part that originally was not effectively loaded from memory
1187 // to pass through.
1188 auto bitCast =
1189 vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1190
1191 Value mask = op.getMask();
1192 auto newSelectMaskType = VectorType::get(
1193 numElements * emulatedPerContainerElem, rewriter.getI1Type());
1194 // TODO: try to fold if op's mask is constant
1195 auto emptyMask =
1196 arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1197 rewriter.getZeroAttr(newSelectMaskType));
1198 if (!foldedIntraVectorOffset) {
1199 mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1200 linearizedInfo.intraDataOffset,
1201 origElements);
1202 } else if (!isDivisibleInSize) {
1203 mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1204 *foldedIntraVectorOffset);
1205 }
1206
1207 Value result =
1208 arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1209 if (!foldedIntraVectorOffset) {
1210 result = dynamicallyExtractSubVector(
1211 rewriter, loc, result, op.getPassThru(),
1212 linearizedInfo.intraDataOffset, origElements);
1213 } else if (!isDivisibleInSize) {
1214 result = staticallyExtractSubvector(
1215 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1216 }
1217 rewriter.replaceOp(op, result);
1218
1219 return success();
1220 }
1221};
1222
1223/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1224///
1225/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1226/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1227/// (a multi-byte scalar, e.g. i16), where N is some integer.
1228///
1229/// Put differently, this method checks whether this would be valid:
1230///
1231/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1232///
1233/// EXAMPLES:
1234/// * vector<4xi4> -> i16 - yes (N = 1)
1235/// * vector<4xi4> -> i8 - yes (N = 2)
1236/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1237/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
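///
/// For the first example this corresponds to the (valid) cast
/// `vector.bitcast %v : vector<4xi4> to vector<1xi16>` (illustrative).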
1238static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1239 Type multiByteScalarTy) {
1240 assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1241
1242 int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1243 int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1244
1245 assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1246 assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1247 assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1248
1249 int elemsPerMultiByte = multiByteBits / subByteBits;
1250
1251 return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// ConvertVectorTransferRead
1256//===----------------------------------------------------------------------===//
1257
1258// TODO: Document-me
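// A sketch of the rewrite (illustrative, mirroring ConvertVectorLoad above;
// `%lin_idx` is the linearized index):
//
//   %v = vector.transfer_read %src[%idx], %pad : memref<?xi4>, vector<8xi4>
//
// becomes, roughly:
//
//   %pad_i8 = arith.extui %pad : i4 to i8
//   %r = vector.transfer_read %src_i8[%lin_idx], %pad_i8
//       : memref<?xi8>, vector<4xi8>
//   %res = vector.bitcast %r : vector<4xi8> to vector<8xi4>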
1259struct ConvertVectorTransferRead final
1260 : OpConversionPattern<vector::TransferReadOp> {
1261 using Base::Base;
1262
1263 LogicalResult
1264 matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter) const override {
1266
1267 // Prerequisites: memref in the vector.transfer_read op is flattened into
1268 // 1-D.
1269 if (op.getVectorType().getRank() != 1)
1270 return rewriter.notifyMatchFailure(
1271 op, "Memref in emulated vector ops must be flattened beforehand.");
1272
1273 auto loc = op.getLoc();
1274 auto containerElemTy =
1275 cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1276 Type emulatedElemTy = op.getType().getElementType();
1277 int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1278 int containerBits = containerElemTy.getIntOrFloatBitWidth();
1279
1280 // Check per-element alignment.
1281 if (containerBits % emulatedBits != 0) {
1282 return rewriter.notifyMatchFailure(
1283 op, "impossible to pack emulated elements into container elements "
1284 "(bit-wise misalignment)");
1285 }
1286 int emulatedPerContainerElem = containerBits / emulatedBits;
1287
1288 auto origElements = op.getVectorType().getNumElements();
1289
1290 // Note, per-element-alignment was already verified above.
1291 bool isDivisibleInSize =
1292 fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1293
1294 // Pad the padding value with 0s on the left. These bits are discarded and
1295 // thus their values don't matter.
1296 Value padding = adaptor.getPadding();
1297 if (!padding.getType().isInteger()) {
1298 padding = arith::BitcastOp::create(
1299 rewriter, loc,
1300 IntegerType::get(rewriter.getContext(),
1301 padding.getType().getIntOrFloatBitWidth()),
1302 padding);
1303 }
1304 auto newPadding =
1305 arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1306
1307 auto stridedMetadata =
1308 memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1309
1310 OpFoldResult linearizedIndices;
1311 memref::LinearizedMemRefInfo linearizedInfo;
1312 std::tie(linearizedInfo, linearizedIndices) =
1313 memref::getLinearizedMemRefOffsetAndSize(
1314 rewriter, loc, emulatedBits, containerBits,
1315 stridedMetadata.getConstifiedMixedOffset(),
1316 stridedMetadata.getConstifiedMixedSizes(),
1317 stridedMetadata.getConstifiedMixedStrides(),
1318 getAsOpFoldResult(adaptor.getIndices()));
1319
1320 std::optional<int64_t> foldedIntraVectorOffset =
1321 isDivisibleInSize ? 0
1322 : getConstantIntValue(linearizedInfo.intraDataOffset);
1323
1324 int64_t maxIntraDataOffset =
1325 foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1326 auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1327 emulatedPerContainerElem);
1328
1329 auto newRead = vector::TransferReadOp::create(
1330 rewriter, loc, VectorType::get(numElements, containerElemTy),
1331 adaptor.getBase(),
1332 getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1333 newPadding);
1334
1335 auto bitCast = vector::BitCastOp::create(
1336 rewriter, loc,
1337 VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1338 newRead);
1339
1340 Value result = bitCast->getResult(0);
1341 if (!foldedIntraVectorOffset) {
1342 auto zeros = arith::ConstantOp::create(
1343 rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1344 result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1345 linearizedInfo.intraDataOffset,
1346 origElements);
1347 } else if (!isDivisibleInSize) {
1348 result = staticallyExtractSubvector(
1349 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1350 }
1351 rewriter.replaceOp(op, result);
1352
1353 return success();
1354 }
1355};
1356} // end anonymous namespace
1357
1358//===----------------------------------------------------------------------===//
1359// RewriteBitCastOfTruncI
1360//===----------------------------------------------------------------------===//
1361
1362namespace {
1363
1364/// Helper struct to keep track of the provenance of a contiguous set of bits
1365/// in a source vector.
1366struct SourceElementRange {
1367 /// The index of the source vector element that contributes bits to *this.
1368 int64_t sourceElementIdx;
1369 /// The range of bits in the source vector element that contribute to *this.
1370 int64_t sourceBitBegin;
1371 int64_t sourceBitEnd;
1372};
1373
1374struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1375 /// Given the index of a SourceElementRange in the SourceElementRangeList,
1376 /// compute the amount of bits that need to be shifted to the left to get the
1377 /// bits in their final location. This shift amount is simply the sum of the
1378 /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1379/// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
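  /// For example (illustrative), for the row
  /// `{1: b@[3..5)}{2: b@[0..5)}{3: b@[0..1)}` from the table further below,
  /// the left-shift amounts are 0, 2 and 7 for `shuffleIdx` 0, 1 and 2.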
1380 int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1381 int64_t res = 0;
1382 for (int64_t i = 0; i < shuffleIdx; ++i)
1383 res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1384 return res;
1385 }
1386};
1387
1388/// Helper struct to enumerate the source elements and bit ranges that are
1389/// involved in a bitcast operation.
1390/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1391/// any 1-D vector shape and any source/target bitwidths.
1392/// This creates and holds a mapping of the form:
1393/// [dstVectorElementJ] ==
1394/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1395/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1396/// [0] = {0, [0-8)}
1397/// [1] = {0, [8-16)}
1398/// [2] = {0, [16-24)}
1399/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1400/// [0] = {0, [0, 10)}, {1, [0, 5)}
1401/// [1] = {1, [5, 10)}, {2, [0, 10)}
1402struct BitCastBitsEnumerator {
1403 BitCastBitsEnumerator(VectorType sourceVectorType,
1404 VectorType targetVectorType);
1405
1406 int64_t getMaxNumberOfEntries() {
1407 int64_t numVectors = 0;
1408 for (const auto &l : sourceElementRanges)
1409 numVectors = std::max(numVectors, (int64_t)l.size());
1410 return numVectors;
1411 }
1412
1413 VectorType sourceVectorType;
1414 VectorType targetVectorType;
1415 SmallVector<SourceElementRangeList> sourceElementRanges;
1416};
1417
1418/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1419/// advantage of high-level information to avoid leaving LLVM to scramble with
1420/// peephole optimizations.
1421/// BitCastBitsEnumerator encodes for each element of the target vector the
1422/// provenance of the bits in the source vector. We can "transpose" this
1423/// information to build a sequence of shuffles and bitwise ops that will
1424/// produce the desired result.
1425//
1426/// Consider the following motivating example:
1427/// ```
1428/// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1429/// ```
1430//
1431/// BitCastBitsEnumerator contains the following information:
1432/// ```
1433/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1434/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1435/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1436/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1437/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1438/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1439/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1440/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1441/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1442/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1443/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1444/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1445/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1446/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1447/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1448/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1449/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1450/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1451/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1452/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1453/// ```
1454///
1455/// In the above, each row represents one target vector element and each
1456/// column represents one bit contribution from a source vector element.
1457/// The algorithm creates vector.shuffle operations (in this case there are 3
1458/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1459/// algorithm populates the bits as follows:
1460/// ```
1461/// src bits 0 ...
1462/// 1st shuffle |xxxxx |xx |...
1463/// 2nd shuffle | xxx| xxxxx |...
1464/// 3rd shuffle | | x|...
1465/// ```
1466//
1467/// The algorithm proceeds as follows:
1468/// 1. for each vector.shuffle, collect the source vectors that participate in
1469/// this shuffle. One source vector per target element of the resulting
1470/// vector.shuffle. If there is no source element contributing bits for the
1471/// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1472/// 2 columns).
1473/// 2. represent the bitrange in the source vector as a mask. If there is no
1474/// source element contributing bits for the current vector.shuffle, take 0.
1475/// 3. shift right by the proper amount to align the source bitrange at
1476/// position 0. This is exactly the low end of the bitrange. For instance,
1477/// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1478/// shift right by 3 to get the bits contributed by the source element #1
1479/// into position 0.
1480/// 4. shift left by the proper amount to align to the desired position in
1481/// the result element vector. For instance, the contribution of the second
1482/// source element for the first row needs to be shifted by `5` to form the
1483/// first i8 result element.
1484///
1485/// Eventually, we end up building the sequence
1486/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1487/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1488/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1489struct BitCastRewriter {
1490 /// Helper metadata struct to hold the static quantities for the rewrite.
1491 struct Metadata {
1492 SmallVector<int64_t> shuffles;
1493 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1494 };
1495
1496 BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1497
1498 /// Verify that general preconditions for the rewrite are met.
1499 LogicalResult commonPrecondition(PatternRewriter &rewriter,
1500 VectorType preconditionType, Operation *op);
1501
1502 /// Precompute the metadata for the rewrite.
1503 SmallVector<BitCastRewriter::Metadata>
1504 precomputeMetadata(IntegerType shuffledElementType);
1505
1506 /// Rewrite one step of the sequence:
1507 /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1508 Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1509 Value initialValue, Value runningResult,
1510 const BitCastRewriter::Metadata &metadata);
1511
1512private:
1513  /// Underlying enumerator that encodes the provenance of the bits in each
1514 /// element of the result vector.
1515 BitCastBitsEnumerator enumerator;
1516};
1517
1518} // namespace
1519
1520[[maybe_unused]] static raw_ostream &
1521operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1522  for (const auto &l : vec) {
1523 for (auto it : llvm::enumerate(l)) {
1524 os << "{ " << it.value().sourceElementIdx << ": b@["
1525 << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1526 << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1527 }
1528 os << "\n";
1529 }
1530 return os;
1531}
1532
1533BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1534 VectorType targetVectorType)
1535 : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1536
1537 assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1538         "requires 1-D non-scalable vector type");
1539 assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1540         "requires 1-D non-scalable vector type");
1541 int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1542 int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1543 LDBG() << "sourceVectorType: " << sourceVectorType;
1544
1545 int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1546 int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1547 LDBG() << "targetVectorType: " << targetVectorType;
1548
1549 int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1550 (void)mostMinorSourceDim;
1551 assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1552 "source and target bitwidths must match");
1553
1554 // Prepopulate one source element range per target element.
1555 sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1556 for (int64_t resultBit = 0; resultBit < bitwidth;) {
1557 int64_t resultElement = resultBit / targetBitWidth;
1558 int64_t resultBitInElement = resultBit % targetBitWidth;
1559 int64_t sourceElementIdx = resultBit / sourceBitWidth;
1560 int64_t sourceBitInElement = resultBit % sourceBitWidth;
1561 int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1562 targetBitWidth - resultBitInElement);
1563 sourceElementRanges[resultElement].push_back(
1564 {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1565 resultBit += step;
1566 }
1567}
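// For illustration, assume a hypothetical
// `vector.bitcast ... : vector<4xi8> to vector<1xi32>`: with sourceBitWidth =
// 8 and targetBitWidth = 32, the loop above runs four times with step = 8 and
// produces
//   sourceElementRanges[0] = { {0, [0,8)}, {1, [0,8)}, {2, [0,8)}, {3, [0,8)} }
// so getMaxNumberOfEntries() == 4 and computeLeftShiftAmount(j) == 8 * j.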
1568
1569BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1570 VectorType targetVectorType)
1571 : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1572 LDBG() << "\n" << enumerator.sourceElementRanges;
1573}
1574
1575/// Verify that the precondition type meets the common preconditions for any
1576/// conversion.
1577static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1578 VectorType preconditionType,
1579 Operation *op) {
1580 if (!preconditionType || preconditionType.isScalable())
1581 return rewriter.notifyMatchFailure(op, "scalable vector");
1582
1583 // TODO: consider relaxing this restriction in the future if we find ways
1584 // to really work with subbyte elements across the MLIR/LLVM boundary.
1585 unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1586 if (bitwidth % 8 != 0)
1587 return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1588
1589 return success();
1590}
1591
1592LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1593 VectorType preconditionType,
1594 Operation *op) {
1595 if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1596 return rewriter.notifyMatchFailure(op, "types are not vector");
1597
1598 if (!preconditionType || preconditionType.getRank() != 1)
1599 return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1600
1601 return commonConversionPrecondition(rewriter, preconditionType, op);
1602}
1603
1604/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1605///
1606/// Alignment means that `subByteVecTy` can be packed into a vector of
1607/// `containerTy` elements. More specifically:
1608/// 1. The bit-width of `containerTy` is a multiple of the
1609/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1610/// this multiple is 4.
1611/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1612/// elements in `subByteVecTy`.
1613///
1614/// EXAMPLE 1:
1615/// `subByteVecTy = vector<2xi4>`, and
1616/// `containerTy = i16`
1617///
1618/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1619///
1620/// EXAMPLE 2:
1621/// `subByteVecTy = vector<3xi4>`, and
1622/// `containerTy = i16`
1623///
1624/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1625///
1626/// EXAMPLE 3:
1627/// `subByteVecTy = vector<3xi3>`, and
1628/// `containerTy = i16`
1629///
1630/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1631///
1632/// NOTE: This method assumes that common conversion preconditions are met. In
1633/// particular, `containerTy` is assumed to be a
1634/// multi-byte scalar type (e.g., i8, i16, i32).
1635static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1636 VectorType subByteVecTy,
1637 Type containerTy,
1638 Operation *op) {
1639 assert(containerTy.isIntOrFloat() &&
1640 "container element type is not a scalar");
1641
1642 // TODO: This is validating the inputs rather than checking the conditions
1643 // documented above. Replace with an assert.
1644 if (!subByteVecTy)
1645 return rewriter.notifyMatchFailure(op, "not a vector!");
1646
1647 unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1648 unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1649
1650 // Enforced by the common pre-conditions.
1651 assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1652
1653  // TODO: Add support for other widths (when/if needed).
1654 if (subByteBits != 2 && subByteBits != 4)
1655 return rewriter.notifyMatchFailure(
1656        op, "only 2-bit and 4-bit sub-byte types are supported at the moment");
1657
1658 // Condition 1 ("per-element" alignment)
1659 if (containerBits % subByteBits != 0)
1660    return rewriter.notifyMatchFailure(op, "unaligned element types");
1661
1662 // Condition 2 ("full" alignment)
1663 if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1664 return rewriter.notifyMatchFailure(
1665 op, "not possible to fit this sub-byte vector type into a vector of "
1666 "the given multi-byte type");
1667
1668 return success();
1669}
1670
1671SmallVector<BitCastRewriter::Metadata>
1672BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1673 SmallVector<BitCastRewriter::Metadata> result;
1674 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1675 shuffleIdx < e; ++shuffleIdx) {
1676 SmallVector<int64_t> shuffles;
1677 SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1678
1679 // Create the attribute quantities for the shuffle / mask / shift ops.
1680 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1681 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1682 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1683 : 0;
1684 shuffles.push_back(sourceElement);
1685
1686 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1687 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1688 : 0;
1689 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1690 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1691 : 0;
1692 IntegerAttr mask = IntegerAttr::get(
1693 shuffledElementType,
1694 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1695 bitLo, bitHi));
1696 masks.push_back(mask);
1697
1698 int64_t shiftRight = bitLo;
1699 shiftRightAmounts.push_back(
1700 IntegerAttr::get(shuffledElementType, shiftRight));
1701
1702 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1703 shiftLeftAmounts.push_back(
1704 IntegerAttr::get(shuffledElementType, shiftLeft));
1705 }
1706
1707 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1708 }
1709 return result;
1710}
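// Continuing the hypothetical vector<4xi8> -> vector<1xi32> example above
// (assuming the bitcast is fed by `arith.trunci %in : vector<4xi32> to
// vector<4xi8>`, so shuffledElementType is i32), precomputeMetadata would
// return four entries, one per shuffle step j in [0, 4):
//   shuffles = [j], masks = [0xFF], shiftRightAmounts = [0],
//   shiftLeftAmounts = [8 * j]
// i.e. each step selects source element j, keeps its low 8 bits, and moves
// them to bit position 8 * j of the single i32 result element.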
1711
1712Value BitCastRewriter::genericRewriteStep(
1713 PatternRewriter &rewriter, Location loc, Value initialValue,
1714 Value runningResult, const BitCastRewriter::Metadata &metadata) {
1715 // Create vector.shuffle from the metadata.
1716 auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1717 initialValue, metadata.shuffles);
1718
1719 // Intersect with the mask.
1720 VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1721 auto constOp = arith::ConstantOp::create(
1722 rewriter, loc,
1723 DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1724 Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1725
1726 // Align right on 0.
1727 auto shiftRightConstantOp = arith::ConstantOp::create(
1728 rewriter, loc,
1729 DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1730 Value shiftedRight =
1731 arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1732
1733 // Shift bits left into their final position.
1734 auto shiftLeftConstantOp = arith::ConstantOp::create(
1735 rewriter, loc,
1736 DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1737 Value shiftedLeft =
1738 arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1739
1740 runningResult =
1741 runningResult
1742 ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1743 : shiftedLeft;
1744
1745 return runningResult;
1746}
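// As a sketch (assuming the same hypothetical vector<4xi32> example, where
// %in is the pre-truncation operand and %running the result of the previous
// step), the step for shuffleIdx = 1 would emit roughly:
// ```
//   %s   = vector.shuffle %in, %in [1] : vector<4xi32>, vector<4xi32>
//   %m   = arith.constant dense<255> : vector<1xi32>
//   %and = arith.andi %s, %m : vector<1xi32>
//   %c0  = arith.constant dense<0> : vector<1xi32>
//   %shr = arith.shrui %and, %c0 : vector<1xi32>
//   %c8  = arith.constant dense<8> : vector<1xi32>
//   %shl = arith.shli %shr, %c8 : vector<1xi32>
//   %res = arith.ori %running, %shl : vector<1xi32>
// ```
// For shuffleIdx = 0 the `arith.ori` is omitted and %shl itself becomes the
// running result.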
1747
1748/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1749/// Where aligned means it satisfies the alignedConversionPreconditions.
1750///
1751/// Example:
1752/// vector<16x16xi2> -> vector<16x4xi8>
1753/// vector<16x16xi4> -> vector<16x8xi8>
1754static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1755                                      Value subByteVec) {
1756 auto srcVecType = cast<VectorType>(subByteVec.getType());
1757 int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1758 assert(8 % srcBitwidth == 0 &&
1759 "Unsupported sub-byte type (not a divisor of i8)");
1760 int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1761 SmallVector<int64_t> vecShape(srcVecType.getShape());
1762 // Adjust last dimension of the vector, so the total size remains the same.
1763 vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1764 auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1765 return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1766}
1767
1768/// Extracts a signed N-bit sequence from each element of a vector of bytes,
1769/// starting at the specified bit index.
1770/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1771///
1772/// Example for a single element:
1773/// Extract numBits=2 starting at bitIdx=2
1774/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1775/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1776/// target = [. . . . ^ ^ . .]
1777///
1778/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1779/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1780///
1781/// src = [01 01 11 10]
1782/// shl = arith.shl(src, 4) -> [11 10 00 00]
1783/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1784static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1785                                                  Location loc, Value src,
1786 int bitIdx, int numBits) {
1787 auto srcType = cast<VectorType>(src.getType());
1788 Value shl = src;
1789 int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1790 assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1791 "Invalid bitIdx range");
1792 if (bitsToShiftLeft != 0) {
1793 Value shiftLeftValues = arith::ConstantOp::create(
1794 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1795 shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1796 }
1797
1798 int8_t bitsToShiftRight = 8 - numBits;
1799 Value shiftRightValues = arith::ConstantOp::create(
1800 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1801 Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1802 return shr;
1803}
1804
1805/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1806/// starting at the specified bit index.
1807/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1808///
1809/// Example for a single element:
1810/// Extract numBits=2 starting at bitIdx=2
1811/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1812/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1813/// target = [. . . . ^ ^ . .]
1814///
1815/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1816/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1817///
1818/// src = [01 01 10 10]
1819/// mask = [00 00 00 11]
1820/// shr = arith.shrui(src, 2) = [00 01 01 10]
1821/// result = arith.andi(shr, mask) = [00 00 00 10]
1822/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1823/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1824/// However, by using arith::ShRUIOp + arith::AndIOp, we avoid the left shift
1825/// when the index is 0.
1826static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1827                                              Location loc, Value src,
1828 int bitIdx, int numBits) {
1829 assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1830 "Invalid bitIdx range");
1831 auto srcType = cast<VectorType>(src.getType());
1832 int8_t bitsToShiftRight = bitIdx;
1833 Value shr = src;
1834 if (bitsToShiftRight != 0) {
1835 Value shiftRightValues = arith::ConstantOp::create(
1836 rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1837 shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1838 }
1839 if (bitIdx + numBits == 8) {
1840 return shr;
1841 }
1842 uint8_t lowBitsMask = (1 << numBits) - 1;
1843 Value lowBitsMaskValues = arith::ConstantOp::create(
1844 rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1845 return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1846}
1847
1848using ExtractNBitsFn =
1849    std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1850
1851/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1852/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1853static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1854                              Value srcValue, const ExtractNBitsFn &extFn) {
1855 [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1856 assert(srcVecType.getElementType().isSignlessInteger(4) &&
1857 "Expected i4 type");
1858
1859 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1860 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1861
1862  // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1863  // byte are placed in one vector and the high i4 elements in another vector.
1864 Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1865 Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1866
1867 // 3. Interleave low and high i8 elements.
1868 return vector::InterleaveOp::create(rewriter, loc, low, high);
1869}
1870
1871/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1872/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1873static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1874                              Value srcValue, const ExtractNBitsFn &extFn) {
1875 [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1876 assert(srcVecType.getElementType().isSignlessInteger(2) &&
1877 "Expected i2 type");
1878
1879 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1880 Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1881
1882 // 2. Extract each i2 element
1883  // Position 0 (bits 0-1)
1884 Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1885 // Position 1 (bits 2-3)
1886 Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1887 // Position 2 (bits 4-5)
1888 Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1889 // Position 3 (bits 6-7)
1890 Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1891
1892 // 3. Interleave all 4 elements by first interleaving
1893 // even elements and then odd
1894 // vec0 = [0,0,0,0],...
1895 // vec1 = [1,1,1,1],...
1896 // vec2 = [2,2,2,2],...
1897 // vec3 = [3,3,3,3],...
1898 // 02 = [0,2,0,2,0,2,0,2],...
1899 // 13 = [1,3,1,3,1,3,1,3],...
1900 // 0213 = [0,1,2,3,...],...
1901 Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1902 Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1903 return vector::InterleaveOp::create(rewriter, loc, interleave02,
1904 interleave13);
1905}
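// As a sketch, assuming a hypothetical vector<8xi2> input and the unsigned
// extract variant, the sequence built above would read roughly (%cN stands
// for a dense<N> splat constant of type vector<2xi8>, omitted for brevity):
// ```
//   %b  = vector.bitcast %in : vector<8xi2> to vector<2xi8>
//   %v0 = arith.andi %b, %c3 : vector<2xi8>     // bits [0, 2)
//   %s1 = arith.shrui %b, %c2 : vector<2xi8>
//   %v1 = arith.andi %s1, %c3 : vector<2xi8>    // bits [2, 4)
//   %s2 = arith.shrui %b, %c4 : vector<2xi8>
//   %v2 = arith.andi %s2, %c3 : vector<2xi8>    // bits [4, 6)
//   %v3 = arith.shrui %b, %c6 : vector<2xi8>    // bits [6, 8)
//   %e  = vector.interleave %v0, %v2 : vector<2xi8> -> vector<4xi8>
//   %o  = vector.interleave %v1, %v3 : vector<2xi8> -> vector<4xi8>
//   %r  = vector.interleave %e, %o : vector<4xi8> -> vector<8xi8>
// ```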
1906
1907/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1908/// ops to avoid leaving LLVM to scramble with peephole optimizations.
1909static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1910                                Value srcValue) {
1911 VectorType srcVecType = cast<VectorType>(srcValue.getType());
1912 assert(srcVecType.getElementType().isSignlessInteger(8) &&
1913 "Expected i8 type");
1914
1915 // 1. De-interleave low and high i8 elements.
1916 auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1917
1918 // 2. Zero out the upper side of each low i8 element.
1919 constexpr int8_t i8LowBitMask = 0x0F;
1920 VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1921 Value zeroOutMask = arith::ConstantOp::create(
1922 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1923 Value zeroOutLow = arith::AndIOp::create(
1924 rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1925
1926 // 3. Move high i4 values to upper side of the byte.
1927 constexpr int8_t bitsToShift = 4;
1928 auto shiftValues = arith::ConstantOp::create(
1929 rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1930 Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1931 shiftValues);
1932
1933 // 4. Merge high and low i4 values.
1934 auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1935
1936 // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1937 auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1938 return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1939}
1940
1941namespace {
1942/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1943/// advantage of high-level information to avoid leaving LLVM to scramble with
1944/// peephole optimizations.
1945struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1946 using Base::Base;
1947
1948 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1949 PatternRewriter &rewriter) const override {
1950 // The source must be a trunc op.
1951 auto truncOp =
1952 bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1953 if (!truncOp)
1954 return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1955
1956 // Set up the BitCastRewriter and verify the precondition.
1957 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1958 VectorType targetVectorType = bitCastOp.getResultVectorType();
1959 BitCastRewriter bcr(sourceVectorType, targetVectorType);
1960 if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1961 return failure();
1962
1963 // Perform the rewrite.
1964 Value truncValue = truncOp.getIn();
1965 auto shuffledElementType =
1966 cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1967 Value runningResult;
1968 for (const BitCastRewriter ::Metadata &metadata :
1969 bcr.precomputeMetadata(shuffledElementType)) {
1970 runningResult = bcr.genericRewriteStep(
1971 rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1972 }
1973
1974 // Finalize the rewrite.
1975 bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1976 shuffledElementType.getIntOrFloatBitWidth();
1977 if (narrowing) {
1978 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1979 rewriter.replaceOp(bitCastOp, runningResult);
1980 } else {
1981 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1982 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1983 }
1984 } else {
1985 if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1986 rewriter.replaceOp(bitCastOp, runningResult);
1987 } else {
1988 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1989 bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1990 }
1991 }
1992
1993 return success();
1994 }
1995};
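// Tying the pieces together for the same hypothetical example, the pattern
// would rewrite
// ```
//   %t = arith.trunci %in : vector<4xi32> to vector<4xi8>
//   %r = vector.bitcast %t : vector<4xi8> to vector<1xi32>
// ```
// into four shuffle/and/shift/or steps applied directly to %in; no trailing
// arith.trunci or arith.extui is needed because the running result already
// has the vector<1xi32> result type.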
1996} // namespace
1997
1998//===----------------------------------------------------------------------===//
1999// RewriteExtOfBitCast
2000//===----------------------------------------------------------------------===//
2001
2002namespace {
2003/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
2004/// take advantage of high-level information to avoid leaving LLVM to scramble
2005/// with peephole optimizations.
2006template <typename ExtOpType>
2007struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2008 using OpRewritePattern<ExtOpType>::OpRewritePattern;
2009
2010 RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2011 : OpRewritePattern<ExtOpType>(context, benefit) {}
2012
2013 LogicalResult matchAndRewrite(ExtOpType extOp,
2014 PatternRewriter &rewriter) const override {
2015 // The source must be a bitcast op.
2016 auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2017 if (!bitCastOp)
2018 return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2019
2020 // Set up the BitCastRewriter and verify the precondition.
2021 VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2022 VectorType targetVectorType = bitCastOp.getResultVectorType();
2023 BitCastRewriter bcr(sourceVectorType, targetVectorType);
2024 if (failed(bcr.commonPrecondition(
2025 rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2026 return failure();
2027
2028 // Perform the rewrite.
2029 Value runningResult;
2030 Value sourceValue = bitCastOp.getSource();
2031 auto shuffledElementType =
2032 cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2033 for (const BitCastRewriter::Metadata &metadata :
2034 bcr.precomputeMetadata(shuffledElementType)) {
2035 runningResult = bcr.genericRewriteStep(
2036 rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2037 }
2038
2039 // Finalize the rewrite.
2040 bool narrowing =
2041 cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2042 shuffledElementType.getIntOrFloatBitWidth();
2043 if (narrowing) {
2044 rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2045 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2046 } else {
2047 rewriter.replaceOpWithNewOp<ExtOpType>(
2048 extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2049 }
2050
2051 return success();
2052 }
2053};
2054
2055/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2056/// bitwise ops that take advantage of high-level information to avoid leaving
2057/// LLVM to scramble with peephole optimizations. Templated to choose between
2058/// signed and unsigned conversions.
2059///
2060/// EXAMPLE 1 (signed):
2061/// arith.extsi %in : vector<8xi4> to vector<8xi32>
2062/// is rewritten as:
2063/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2064/// %1 = arith.shli %0, 4 : vector<4xi8>
2065/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2066/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2067/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2068/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2069///
2070/// EXAMPLE 2 (fp):
2071/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2072/// is rewritten as:
2073/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2074/// %1 = arith.shli %0, 4 : vector<4xi8>
2075/// %2 = arith.shrsi %1, 4 : vector<4xi8>
2076/// %3 = arith.shrsi %0, 4 : vector<4xi8>
2077/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2078/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2079///
2080/// EXAMPLE 3 (unsigned):
2081/// arith.extui %in : vector<8xi4> to vector<8xi32>
2082/// is rewritten as:
2083/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2084/// %1 = arith.andi %0, 15 : vector<4xi8>
2085/// %2 = arith.shrui %0, 4 : vector<4xi8>
2086/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2087/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2088///
2089template <typename ConversionOpType, bool isSigned>
2090struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2091 using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2092
2093 LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2094 PatternRewriter &rewriter) const override {
2095 // Verify the preconditions.
2096 Value srcValue = conversionOp.getIn();
2097 VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2098 VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2099
2100 if (failed(
2101 commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2102 return failure();
2103
2104 // Check general alignment preconditions.
2105    if (failed(alignedConversionPrecondition(
2106            rewriter, srcVecType,
2107 /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2108 return failure();
2109
2110 // Perform the rewrite.
2111 Location loc = conversionOp.getLoc();
2112 const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2113                                 : extractNBitsPerByteAndExtendToI8;
2114    Value subByteExt;
2115 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2116 case 2:
2117 subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2118 break;
2119 case 4:
2120 subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2121 break;
2122 default:
2123 return failure();
2124 }
2125
2126 // Finalize the rewrite.
2127 rewriter.replaceOpWithNewOp<ConversionOpType>(
2128 conversionOp, conversionOp.getType(), subByteExt);
2129 return success();
2130 }
2131};
2132
2133/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2134/// bitwise ops that take advantage of high-level information to avoid leaving
2135/// LLVM to scramble with peephole optimizations.
2136///
2137/// For example:
2138/// arith.trunci %in : vector<8xi32> to vector<8xi4>
2139///
2140/// is rewritten as (%in is first truncated to an i8 vector, %in_i8):
2141///
2142/// %cst = arith.constant dense<15> : vector<4xi8>
2143/// %cst_0 = arith.constant dense<4> : vector<4xi8>
2144///   %0, %1 = vector.deinterleave %in_i8 : vector<8xi8> -> vector<4xi8>
2145/// %2 = arith.andi %0, %cst : vector<4xi8>
2146/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2147/// %4 = arith.ori %2, %3 : vector<4xi8>
2148/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2149///
2150struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2151 using Base::Base;
2152
2153 LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2154 PatternRewriter &rewriter) const override {
2155 // Verify the preconditions.
2156 Value srcValue = truncOp.getIn();
2157 auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2158 auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2159 if (!srcVecType || !dstVecType)
2160 return failure();
2161
2162 if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2163 return failure();
2164
2165 // TODO: Add support for truncating to i2.
2166 if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2167 return failure();
2168
2169 // Check general alignment preconditions. We invert the src/dst type order
2170 // to reuse the existing precondition logic.
2171    if (failed(alignedConversionPrecondition(
2172            rewriter, dstVecType,
2173 /*containerTy=*/rewriter.getI8Type(), truncOp)))
2174 return failure();
2175
2176 // Create a new iX -> i8 truncation op.
2177 Location loc = truncOp.getLoc();
2178 auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2179 Value i8TruncVal =
2180 arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2181
2182 // Rewrite the i8 -> i4 truncation part.
2183 Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2184
2185 // Finalize the rewrite.
2186 rewriter.replaceOp(truncOp, subByteTrunc);
2187 return success();
2188 }
2189};
2190
2191/// Rewrite a sub-byte vector transpose into a sequence of instructions that
2192/// perform the transpose on wider (byte) element types.
2193///
2194/// EXAMPLE:
2195/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2196///
2197/// is rewritten as:
2198///
2199/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2200/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2201/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2202///
2203struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2204 using Base::Base;
2205
2206 RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2207 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2208
2209 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2210 PatternRewriter &rewriter) const override {
2211 // Precondition: sub-byte integer transpose.
2212 constexpr unsigned minNativeBitwidth = 8;
2213 VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2214 if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2215 srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2216 return rewriter.notifyMatchFailure(transposeOp,
2217 "not a sub-byte transpose");
2218 }
2219
2220 // Perform the rewrite.
2221 Location loc = transposeOp.getLoc();
2222 // Signed/unsigned interpretation shouldn't matter here as we are just
2223 // transposing the elements and truncating them back to the original size.
2224 // TODO: Use unsigned extension (more efficient) when emulation or backend
2225 // support is available.
2226 auto srcNativeVecType = srcSubByteVecType.cloneWith(
2227 std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2228 Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2229 transposeOp.getVector());
2230 Value newTranspose = vector::TransposeOp::create(
2231 rewriter, loc, extOp, transposeOp.getPermutation());
2232 VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2233 rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2234 newTranspose);
2235 return success();
2236 }
2237};
2238
2239} // namespace
2240
2241//===----------------------------------------------------------------------===//
2242// Public Interface Definition
2243//===----------------------------------------------------------------------===//
2244
2245// The emulated type is inferred from the converted memref type.
2246void vector::populateVectorNarrowTypeEmulationPatterns(
2247 const arith::NarrowTypeEmulationConverter &typeConverter,
2248 RewritePatternSet &patterns, bool disableAtomicRMW) {
2249 // Populate `vector.*` conversion patterns.
2250 // TODO: #119553 support atomicity
2251 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2252 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2253 typeConverter, patterns.getContext());
2254
2255  // Populate `vector.*` store conversion patterns. The caller can choose
2256  // to avoid emitting atomic operations and reduce stores to plain
2257  // read-modify-write sequences if it is known that there is no thread
2257  // contention.
2258 patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2259}
2260
2261void vector::populateVectorNarrowTypeRewritePatterns(
2262 RewritePatternSet &patterns, PatternBenefit benefit) {
2263 // TODO: Document what the emulated type is.
2264 patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2265 RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2266 benefit);
2267
2268  // Patterns for aligned cases. We assign a higher benefit as they are
2269  // expected to generate better code for the cases they cover.
2270 // The container type is always i8.
2271 patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2272 RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2273 RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2274 benefit.getBenefit() + 1);
2275 // The container type is always i8.
2276 patterns
2277 .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2278 RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2279 patterns.getContext(), benefit.getBenefit() + 1);
2280}
2281
2282// The container type is always i8.
2283void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2284 RewritePatternSet &patterns, PatternBenefit benefit) {
2285 patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2286}
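// As a usage sketch (assuming the greedy driver entry point
// `applyPatternsGreedily` and that these populate functions are visible
// through their public headers), a pass could apply the rewrite patterns
// above as follows:
// ```
//   void runNarrowTypeRewrites(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     vector::populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1);
//     vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns,
//                                                              /*benefit=*/1);
//     // Best-effort application; convergence failure is not treated as fatal.
//     (void)applyPatternsGreedily(root, std::move(patterns));
//   }
// ```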
2287
2288void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2289 arith::NarrowTypeEmulationConverter &typeConverter,
2290 RewritePatternSet &patterns) {
2291  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
2292  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
2293}