MLIR 23.0.0git
XeGPUToXeVM.cpp
Go to the documentation of this file.
1//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
31
33#include "mlir/IR/Types.h"
34
35#include "llvm/ADT/TypeSwitch.h"
36
37#include <numeric>
38
39namespace mlir {
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
42} // namespace mlir
43
44using namespace mlir;
45
46namespace {
47
// TODO: These values depend on the uArch; they should be queried instead of
// being hardcoded here.
// Systolic depth used when matching sub-byte packed-load tile heights.
static constexpr int32_t systolicDepth = 8;
// Execution size used when matching/scaling sub-byte tile widths.
static constexpr int32_t executionSize = 16;
51
// Field positions inside the nd tensor descriptor payload (an 8xi32 vector).
// BasePtr is used as an index into the 4xi64 view of the payload; the other
// entries are used as indices into the 8xi32 vector itself.
enum class NdTdescOffset : uint32_t {
  // Base pointer (i64).
  BasePtr = 0,
  // Base shape width (i32).
  BaseShapeW = 2,
  // Base shape height (i32).
  BaseShapeH = 3,
  // Base pitch, i.e. the stride of dim rank-2 (i32).
  BasePitch = 4,
};
59
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
66 }
67 llvm_unreachable("Unknown XeGPU memory space");
68}
69
70/// Checks if the given MemRefType refers to shared memory.
71static bool isSharedMemRef(const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
73 if (!attr)
74 return false;
75 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
77 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
80}
81
82// Get same bitwidth flat vector type of new element type.
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
84 Type toElemType) {
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
87 auto newBitWidth = toElemType.getIntOrFloatBitWidth();
88 const int size =
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
91}
92
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
96 // If no hints are provided, use the default cache control.
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
99 // If only one of the hints is provided, use the default for the other level.
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
102 switch (L1hintVal) {
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
108 else
109 llvm_unreachable("Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
115 else
116 llvm_unreachable("Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
122 else
123 llvm_unreachable("Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
126 default:
127 llvm_unreachable("Unsupported cache control.");
128 }
129}
130
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
134 // If no hints are provided, use the default cache control.
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
137 // If only one of the hints is provided, use the default for the other level.
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
140 switch (L1hintVal) {
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
146 else
147 llvm_unreachable("Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
153 else
154 llvm_unreachable("Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
160 else
161 llvm_unreachable("Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
167 else
168 llvm_unreachable("Unsupported cache control.");
169 default:
170 llvm_unreachable("Unsupported cache control.");
171 }
172}
173
174//
175// Note:
176// Block operations on tiles of sub-byte element types are handled by
177// emulating them with larger element types.
178// Tensor descriptors are kept intact; only the ops consuming them are
179// emulated.
180//
181
182class CreateNdDescToXeVMPattern
183 : public OpConversionPattern<xegpu::CreateNdDescOp> {
184 using OpConversionPattern::OpConversionPattern;
185 LogicalResult
186 matchAndRewrite(xegpu::CreateNdDescOp op,
187 xegpu::CreateNdDescOp::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter) const override {
189 auto loc = op.getLoc();
190 auto source = op.getSource();
191 // Op is lowered to a code sequence that populates payload.
192 // Payload is a 8xi32 vector. Offset to individual fields are defined in
193 // NdTdescOffset enum.
194 Type payloadElemTy = rewriter.getI32Type();
195 VectorType payloadTy = VectorType::get(8, payloadElemTy);
196 Type i64Ty = rewriter.getI64Type();
197 // 4xi64 view is used for inserting the base pointer.
198 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
199 // Initialize payload to zero.
200 Value payload = arith::ConstantOp::create(
201 rewriter, loc,
202 DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
203
204 Value baseAddr;
205 Value baseShapeW;
206 Value baseShapeH;
207
208 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
209 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
210 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
211 // Descriptor shape is expected to be 2D.
212 int64_t rank = mixedSizes.size();
213 auto sourceTy = source.getType();
214 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
215 // If source is a memref, we need to extract the aligned pointer as index.
216 // Pointer type is passed as i32 or i64 by type converter.
217 if (sourceMemrefTy) {
218 if (!sourceMemrefTy.hasRank()) {
219 return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
220 }
221 // Access adaptor after failure check to avoid rolling back generated code
222 // for materialization cast.
223 baseAddr = adaptor.getSource();
224 } else {
225 baseAddr = adaptor.getSource();
226 if (baseAddr.getType() != i64Ty) {
227 // Pointer type may be i32. Cast to i64 if needed.
228 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
229 }
230 }
231 // 1D tensor descriptor is just the base address.
232 if (rank == 1) {
233 rewriter.replaceOp(op, baseAddr);
234 return success();
235 }
236 // Utility for creating offset values from op fold result.
237 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
238 unsigned idx) -> Value {
239 Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
240 val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
241 return val;
242 };
243 // For ND descriptors, the last 2 dimensions are the 2D tile (H, W).
244 // Any leading dimensions are batch dims with associated strides.
245 baseShapeW = createOffset(mixedSizes, rank - 1);
246 baseShapeH = createOffset(mixedSizes, rank - 2);
247 // Pitch is the stride of dim rank-2 (the row stride of the 2D tile).
248 Value basePitch = createOffset(mixedStrides, rank - 2);
249 // Populate payload.
250 Value payLoadAsI64 =
251 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
252 payLoadAsI64 =
253 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
254 static_cast<int>(NdTdescOffset::BasePtr));
255 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
256 payload =
257 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
258 static_cast<int>(NdTdescOffset::BaseShapeW));
259 payload =
260 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
261 static_cast<int>(NdTdescOffset::BaseShapeH));
262 payload =
263 vector::InsertOp::create(rewriter, loc, basePitch, payload,
264 static_cast<int>(NdTdescOffset::BasePitch));
265 rewriter.replaceOp(op, payload);
266 return success();
267 }
268};
269
/// Conversion pattern shared by xegpu::LoadNdOp, xegpu::StoreNdOp and
/// xegpu::PrefetchNdOp (the enable_if below restricts OpType to these three).
///
/// For tensor descriptors of rank >= 2 the converted descriptor is unpacked
/// into base pointer, shape width, shape height and pitch, and the op is
/// rewritten to xevm::BlockLoad2dOp, xevm::BlockStore2dOp or
/// xevm::BlockPrefetch2dOp. For rank-1 descriptors the converted descriptor is
/// used as an i64 base address and the op is rewritten to xevm::BlockLoadOp or
/// xevm::BlockStoreOp; a rank-1 prefetch is reported as a match failure.
///
/// Element widths must be a multiple of 8 bits, except for 4-bit elements on
/// 2D descriptors, which are emulated with 8-bit (packed load) or 16-bit
/// elements for the specific tile shapes checked below.
270template <
271 typename OpType,
272 typename = std::enable_if_t<llvm::is_one_of<
273 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
274class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
275 using OpConversionPattern<OpType>::OpConversionPattern;
276 LogicalResult
277 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 auto mixedOffsets = op.getMixedOffsets();
280 int64_t opOffsetsSize = mixedOffsets.size();
281 auto loc = op.getLoc();
282 auto ctxt = rewriter.getContext();
283
284 auto tdesc = adaptor.getTensorDesc();
285 auto tdescTy = op.getTensorDescType();
286 auto tileRank = tdescTy.getRank();
    // One offset per descriptor dimension is required.
287 if (opOffsetsSize != tileRank)
288 return rewriter.notifyMatchFailure(
289 op, "Expected offset rank to match descriptor rank.");
290 auto elemType = tdescTy.getElementType();
291 auto elemBitSize = elemType.getIntOrFloatBitWidth();
292 bool isSubByte = elemBitSize < 8;
    // Factor by which width-related quantities are divided when sub-byte
    // elements are emulated with wider ones; 1 means no emulation.
293 uint64_t wScaleFactor = 1;
294
295 if (!isSubByte && (elemBitSize % 8 != 0))
296 return rewriter.notifyMatchFailure(
297 op, "Expected element type bit width to be multiple of 8.");
298 auto tileW = tdescTy.getDimSize(tileRank - 1);
299 // For sub byte types, only 4bits are currently supported.
300 if (isSubByte) {
301 if (elemBitSize != 4)
302 return rewriter.notifyMatchFailure(
303 op, "Only sub byte types of 4bits are supported.");
304 if (tileRank != 2)
305 return rewriter.notifyMatchFailure(
306 op, "Sub byte types are only supported for 2D tensor descriptors.");
    // Number of sub-byte elements that fit into one byte.
307 auto subByteFactor = 8 / elemBitSize;
308 auto tileH = tdescTy.getDimSize(0);
309 // Handle special case for packed load.
310 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
311 if (op.getPacked().value_or(false)) {
312 // packed load is implemented as packed loads of 8bit elements.
313 if (tileH == systolicDepth * 4 &&
314 tileW == executionSize * subByteFactor) {
315 // Usage case for loading as Matrix B with pack request.
316 // source is assumed to pre-packed into 8bit elements
317 // Emulate with 8bit loads with pack request.
318 // scaled_tileW = executionSize
319 elemType = rewriter.getIntegerType(8);
320 tileW = executionSize;
321 wScaleFactor = subByteFactor;
322 }
323 }
324 }
325 // If not handled by packed load case above, handle other cases.
326 if (wScaleFactor == 1) {
    // Number of sub-byte elements that fit into one 16-bit element.
327 auto sub16BitFactor = subByteFactor * 2;
328 if (tileW == executionSize * sub16BitFactor) {
329 // Usage case for loading as Matrix A operand
330 // Emulate with 16bit loads/stores.
331 // scaled_tileW = executionSize
332 elemType = rewriter.getIntegerType(16);
333 tileW = executionSize;
334 wScaleFactor = sub16BitFactor;
335 } else {
336 return rewriter.notifyMatchFailure(
337 op, "Unsupported tile shape for sub byte types.");
338 }
339 }
340 // recompute element bit size for emulation.
341 elemBitSize = elemType.getIntOrFloatBitWidth();
342 }
343
344 // Get address space from tensor descriptor memory space.
345 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
346 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
347 if (tileRank >= 2) {
348 // Compute element byte size.
349 Value elemByteSize = arith::ConstantIntOp::create(
350 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
    // The base pointer is read through a 4xi64 view of the payload; the
    // remaining fields are read from the payload vector directly.
351 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
352 Value payLoadAsI64 =
353 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
354 Value basePtr =
355 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
356 static_cast<int>(NdTdescOffset::BasePtr));
357 Value baseShapeW = vector::ExtractOp::create(
358 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
359 Value baseShapeH = vector::ExtractOp::create(
360 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
361 Value basePitch = vector::ExtractOp::create(
362 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
363
364 // For rank > 2, leading (batch) dim offsets should be 0 after unrolling
365 // (batch is baked into the base pointer via memref.subview during
366 // blocking). Use only the last 2 offsets for the 2D block operation.
367 Value offsetW = getValueOrCreateConstantIntOp(rewriter, loc,
368 mixedOffsets[tileRank - 1]);
369 offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
370 rewriter.getI32Type(), offsetW);
371 Value offsetH = getValueOrCreateConstantIntOp(rewriter, loc,
372 mixedOffsets[tileRank - 2]);
373 offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
374 rewriter.getI32Type(), offsetH);
375 // Convert base pointer (i64) to LLVM pointer type.
376 Value basePtrLLVM =
377 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
378 // FIXME: width or pitch is not the same as baseShapeW it should be the
379 // stride of the second to last dimension in row major layout.
380 // Compute width in bytes.
381 Value baseShapeWInBytes =
382 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
383 // Compute pitch in bytes.
384 Value basePitchBytes =
385 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
386
387 if (wScaleFactor > 1) {
388 // Scale offsetW, baseShapeWInBytes for sub byte emulation.
389 // Note: tileW is already scaled above.
    // The division by wScaleFactor is emitted as a right shift by its log2.
390 Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
391 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
392 baseShapeWInBytes = arith::ShRSIOp::create(
393 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
394 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
395 wScaleFactorValLog2);
396 offsetW =
397 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
398 }
399 // Get tile height from the tensor descriptor type (second-to-last dim).
400 auto tileH = tdescTy.getDimSize(tileRank - 2);
401 // Get vblocks from the tensor descriptor type.
402 int32_t vblocks = tdescTy.getArrayLength();
403 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
404 Value src = adaptor.getValue();
405 // If store value is a scalar, get value from op instead of adaptor.
406 // Adaptor might have optimized away single element vector
407 if (src.getType().isIntOrFloat()) {
408 src = op.getValue();
409 }
410 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
411 if (!srcVecTy)
412 return rewriter.notifyMatchFailure(
413 op, "Expected store value to be a vector type.");
414 // Get flat vector type of integer type with matching element bit size.
415 VectorType newSrcVecTy =
416 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
417 if (srcVecTy != newSrcVecTy)
418 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
419 auto storeCacheControl =
420 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
421 xevm::BlockStore2dOp::create(
422 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
423 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
424 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
425 rewriter.eraseOp(op);
426 } else {
427 auto loadCacheControl =
428 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
429 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
430 xevm::BlockPrefetch2dOp::create(
431 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
432 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
433 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
434 rewriter.eraseOp(op);
435 } else {
436 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
437 bool vnni = op.getPacked().value_or(false);
438 auto transposeValue = op.getTranspose();
    // Transpose is requested when the attribute is present and its first
    // entry is 1.
439 bool transpose =
440 transposeValue.has_value() && transposeValue.value()[0] == 1;
441 // Handle special case of 32x16 and 8bit element load
442 // with no vnni, no transpose, no vblocks.
443 // For this special case, vnni and non vnni yields the same output
444 // and only the vnni variant is supported by HW.
445 // Check and set vnni of the special case.
446 if (elemBitSize == 8 && tileW == 16 && tileH == 32 && !vnni &&
447 !transpose) {
448 vnni = true;
449 }
450 // Handle transpose request on small element size
451 // Transpose needs to be requested on 32bit element type.
452 // offsetW and tileW needs to be adjusted to account for element type
453 // change.
454 if (transpose && elemBitSize < 32) {
455 int32_t scale = 32 / elemBitSize;
456 Value scaleLog2 = arith::ConstantIntOp::create(
457 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(scale));
458 offsetW = arith::ShRSIOp::create(rewriter, loc, offsetW, scaleLog2);
459 tileW = tileW * elemBitSize / 32;
460 elemBitSize = 32;
461 }
    // The block load produces a flat vector of i32 (vnni) or of integers of
    // the (possibly adjusted) element width; it is bitcast back to the
    // destination element type afterwards.
462 VectorType loadedTy = encodeVectorTypeTo(
463 dstVecTy, vnni ? rewriter.getI32Type()
464 : rewriter.getIntegerType(elemBitSize));
465
466 Value resultFlatVec = xevm::BlockLoad2dOp::create(
467 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
468 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
469 tileH, vblocks, transpose, vnni,
470 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
471 resultFlatVec = vector::BitCastOp::create(
472 rewriter, loc,
473 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
474 resultFlatVec);
475 rewriter.replaceOp(op, resultFlatVec);
476 }
477 }
478 } else {
479 // 1D tensor descriptor.
480 // `tdesc` represents base address as i64
481 // Offset in number of elements, need to multiply by element byte size.
482 // Compute byte offset.
483 // byteOffset = offset * elementByteSize
484 Value offset =
485 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
486 offset = getValueOrCreateCastToIndexLike(rewriter, loc,
487 rewriter.getI64Type(), offset);
488 // Compute element byte size.
489 Value elemByteSize = arith::ConstantIntOp::create(
490 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
491 Value byteOffset =
492 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
493 // Final address = basePtr + byteOffset
494 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
495 loc, tdesc,
496 getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
497 byteOffset));
498 // Convert base pointer (i64) to LLVM pointer type.
499 Value finalPtrLLVM =
500 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
501 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
502 Value src = adaptor.getValue();
503 // If store value is a scalar, get value from op instead of adaptor.
504 // Adaptor might have optimized away single element vector
505 if (src.getType().isIntOrFloat()) {
506 src = op.getValue();
507 }
508 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
509 if (!srcVecTy)
510 return rewriter.notifyMatchFailure(
511 op, "Expected store value to be a vector type.");
512 // Get flat vector type of integer type with matching element bit size.
513 VectorType newSrcVecTy =
514 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
515 if (srcVecTy != newSrcVecTy)
516 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
517 auto storeCacheControl =
518 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
519 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
520 op, finalPtrLLVM, src,
521 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
522 } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
523 auto loadCacheControl =
524 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
525 VectorType resTy = cast<VectorType>(op.getValue().getType());
526 VectorType loadedTy =
527 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
528 Value load = xevm::BlockLoadOp::create(
529 rewriter, loc, loadedTy, finalPtrLLVM,
530 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
531 if (loadedTy != resTy)
532 load = vector::BitCastOp::create(rewriter, loc, resTy, load);
533 rewriter.replaceOp(op, load);
534 } else {
    // Only PrefetchNdOp reaches this branch; it has no rank-1 lowering.
535 return rewriter.notifyMatchFailure(
536 op, "Unsupported operation: xegpu.prefetch_nd with tensor "
537 "descriptor rank == 1");
538 }
539 }
540 return success();
541 }
542};
543
544// Add a builder that creates
545// offset * elemByteSize + baseAddr
546static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
547 Location loc, Value baseAddr, Value offset,
548 int64_t elemByteSize) {
550 rewriter, loc, baseAddr.getType(), elemByteSize);
551 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
552 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
553 return newAddr;
554}
555
/// Conversion pattern shared by xegpu::LoadGatherOp and xegpu::StoreScatterOp
/// (the enable_if below restricts OpType to these two).
///
/// Offset and mask must be scalars after type conversion; vector offsets or
/// masks are reported as a match failure. The accessed address is
/// `base + offset * element byte size`, converted to an LLVM pointer. The
/// access itself is wrapped in an scf::IfOp on the mask: a load yields the
/// loaded value when the mask is true and a zero constant otherwise, a store
/// is only performed when the mask is true. The translated cache hints are
/// attached to the LLVM load/store as the "cache_control" attribute.
556template <typename OpType,
557 typename = std::enable_if_t<llvm::is_one_of<
558 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
559class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
560 using OpConversionPattern<OpType>::OpConversionPattern;
561 LogicalResult
562 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
563 ConversionPatternRewriter &rewriter) const override {
564 Value offset = adaptor.getOffsets();
565 if (!offset)
566 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
567 auto loc = op.getLoc();
568 auto ctxt = rewriter.getContext();
569 Value basePtrI64;
570 // Load result or store value type can be vector or scalar.
571 Type valOrResTy;
572 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
573 valOrResTy =
574 this->getTypeConverter()->convertType(op.getResult().getType());
575 else
576 valOrResTy = adaptor.getValue().getType();
577 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
578 bool hasScalarVal = !valOrResVecTy;
579 int64_t elemBitWidth =
580 hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
581 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
582 // Element type must be multiple of 8 bits.
583 if (elemBitWidth % 8 != 0)
584 return rewriter.notifyMatchFailure(
585 op, "Expected element type bit width to be multiple of 8.");
586 int64_t elemByteSize = elemBitWidth / 8;
587 // Default memory space is global.
588 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
589 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
590 // Base pointer can come from source (load) or dest (store).
591 // If they are memrefs, we use their memory space.
592 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
593 basePtrI64 = adaptor.getSource();
594 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
595 auto addrSpace = memRefTy.getMemorySpaceAsInt();
    // A memory space of 0 keeps the default (global) pointer type.
596 if (addrSpace != 0)
597 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
598 }
599 } else {
600 basePtrI64 = adaptor.getDest();
601 if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
602 auto addrSpace = memRefTy.getMemorySpaceAsInt();
    // A memory space of 0 keeps the default (global) pointer type.
603 if (addrSpace != 0)
604 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
605 }
606 }
607 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
608 if (basePtrI64.getType() != rewriter.getI64Type()) {
609 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
610 basePtrI64);
611 }
612 Value mask = adaptor.getMask();
613 if (dyn_cast<VectorType>(offset.getType())) {
614 // Offset needs be scalar. Single element vector is converted to scalar
615 // by type converter.
616 return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
617 } else {
618 // If offset is provided, we add them to the base pointer.
619 // Offset is in number of elements, we need to multiply by
620 // element byte size.
621 basePtrI64 =
622 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
623 }
624 // Convert base pointer (i64) to LLVM pointer type.
625 Value basePtrLLVM =
626 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
627
628 Value maskForLane;
629 VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
630 if (maskVecTy) {
631 // Mask needs be scalar. Single element vector is converted to scalar by
632 // type converter.
633 return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
634 } else
635 maskForLane = mask;
636 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
637 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
638 maskForLane, true, true);
639 // If mask is true - then clause - load from memory and yield.
640 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
    // A vector result is loaded as a flat 1D vector with the same number of
    // elements.
641 if (!hasScalarVal)
642 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
643 valOrResVecTy.getElementType());
644 Value loaded =
645 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
646 // Set cache control attribute on the load operation.
647 loaded.getDefiningOp()->setAttr(
648 "cache_control", xevm::LoadCacheControlAttr::get(
649 ctxt, translateLoadXeGPUCacheHint(
650 op.getL1Hint(), op.getL3Hint())));
651 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
652 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
653 // If mask is false - else clause - yield zero (scalar or vector of zeros).
654 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
655 TypedAttr eVal;
656 if (eTy.isFloat())
657 eVal = FloatAttr::get(eTy, 0.0);
658 else
659 eVal = IntegerAttr::get(eTy, 0);
660 if (hasScalarVal)
661 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
662 else
663 loaded = arith::ConstantOp::create(
664 rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
665 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
666 rewriter.replaceOp(op, ifOp.getResult(0));
667 } else {
668 // If mask is true, perform the store.
669 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
670 auto body = ifOp.getBody();
671 rewriter.setInsertionPointToStart(body);
672 auto storeOp =
673 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
674 // Set cache control attribute on the store operation.
675 storeOp.getOperation()->setAttr(
676 "cache_control", xevm::StoreCacheControlAttr::get(
677 ctxt, translateStoreXeGPUCacheHint(
678 op.getL1Hint(), op.getL3Hint())));
679 rewriter.eraseOp(op);
680 }
681 return success();
682 }
683};
684
685class CreateMemDescOpPattern final
686 : public OpConversionPattern<xegpu::CreateMemDescOp> {
687public:
688 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
689 LogicalResult
690 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter) const override {
692
693 rewriter.replaceOp(op, adaptor.getSource());
694 return success();
695 }
696};
697
/// Conversion pattern shared by xegpu::LoadMatrixOp and xegpu::StoreMatrixOp
/// (the enable_if below restricts OpType to these two).
///
/// The converted mem_desc operand is used as the base address. The linear
/// offset obtained from MemDescType::getLinearOffsets is cast to i32, scaled by
/// the element byte size and added to it; the result is converted to an LLVM
/// pointer in the SLM address space. When the op carries the subgroup-block-io
/// attribute the access becomes xevm::BlockLoadOp / xevm::BlockStoreOp on an
/// integer vector of the same element width; otherwise it becomes
/// LLVM::LoadOp / LLVM::StoreOp, which requires the chip string to be "pvc",
/// "bmg" or "cri".
698template <typename OpType,
699 typename = std::enable_if_t<llvm::is_one_of<
700 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
701class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
702 using OpConversionPattern<OpType>::OpConversionPattern;
703 LogicalResult
704 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
705 ConversionPatternRewriter &rewriter) const override {
706
707 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
708 if (offsets.empty())
709 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
710
711 auto loc = op.getLoc();
712 auto ctxt = rewriter.getContext();
713 Value baseAddr32 = adaptor.getMemDesc();
714 Value mdescVal = op.getMemDesc();
715 // Load result or Store value Type can be vector or scalar.
716 Type dataTy;
717 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
718 Type resType = op.getResult().getType();
719 // Some transforms may leave unit dimension in the 2D vector, adaptors do
720 // not catch it for results.
721 if (auto vecType = dyn_cast<VectorType>(resType)) {
722 assert(llvm::count_if(vecType.getShape(),
723 [](int64_t d) { return d != 1; }) <= 1 &&
724 "Expected either 1D vector or nD with unit dimensions");
    // Flatten the result type to a 1D vector with the same element count.
725 resType = VectorType::get({vecType.getNumElements()},
726 vecType.getElementType());
727 }
728 dataTy = resType;
729 } else
730 dataTy = adaptor.getData().getType();
    // Scalars are treated as single-element vectors for the checks below.
731 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
732 if (!valOrResVecTy)
733 valOrResVecTy = VectorType::get(1, dataTy);
734
735 int64_t elemBitWidth =
736 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
737 // Element type must be multiple of 8 bits.
738 if (elemBitWidth % 8 != 0)
739 return rewriter.notifyMatchFailure(
740 op, "Expected element type bit width to be multiple of 8.");
741 int64_t elemByteSize = elemBitWidth / 8;
742
743 // Default memory space is SLM.
744 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
745 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
746
747 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
748
749 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
750 linearOffset = arith::IndexCastUIOp::create(
751 rewriter, loc, rewriter.getI32Type(), linearOffset);
752 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
753 linearOffset, elemByteSize);
754
755 // convert base pointer (i32) to LLVM pointer type
756 Value basePtrLLVM =
757 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
758
759 if (op.getSubgroupBlockIoAttr()) {
760 // if the attribute 'subgroup_block_io' is set to true, it lowers to
761 // xevm.blockload
762
    // Block load/store operates on an integer vector of the same shape and
    // element width; a bitcast bridges to/from the original element type.
763 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
764 VectorType intVecTy =
765 VectorType::get(valOrResVecTy.getShape(), intElemTy);
766
767 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
768 Value loadOp =
769 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
770 if (intVecTy != valOrResVecTy) {
771 loadOp =
772 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
773 }
774 rewriter.replaceOp(op, loadOp);
775 } else {
776 Value dataToStore = adaptor.getData();
777 if (valOrResVecTy != intVecTy) {
778 dataToStore =
779 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
780 }
781 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
782 nullptr);
783 rewriter.eraseOp(op);
784 }
785 return success();
786 }
787
    // NOTE(review): `>= 1` also holds for single-element (scalar) accesses, so
    // the chip check below applies to them as well - confirm whether `> 1`
    // was intended.
788 if (valOrResVecTy.getNumElements() >= 1) {
789 auto chipOpt = xegpu::getChipStr(op);
790 if (!chipOpt ||
791 (*chipOpt != "pvc" && *chipOpt != "bmg" && *chipOpt != "cri")) {
792 // the lowering for chunk load only works for pvc, bmg or cri
793 return rewriter.notifyMatchFailure(
794 op, "The lowering is specific to pvc, bmg or cri.");
795 }
796 }
797
798 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
799 // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
800 // operation. LLVM load/store does not support vector of size 1, so we
801 // need to handle this case separately.
802 auto scalarTy = valOrResVecTy.getElementType();
803 LLVM::LoadOp loadOp;
804 if (valOrResVecTy.getNumElements() == 1)
805 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
806 else
807 loadOp =
808 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
809 rewriter.replaceOp(op, loadOp);
810 } else {
811 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
812 rewriter.eraseOp(op);
813 }
814 return success();
815 }
816};
817
818class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
819 using OpConversionPattern::OpConversionPattern;
820 LogicalResult
821 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
822 ConversionPatternRewriter &rewriter) const override {
823 auto loc = op.getLoc();
824 auto ctxt = rewriter.getContext();
825 Value basePtrI64 = adaptor.getSource();
826 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
827 if (basePtrI64.getType() != rewriter.getI64Type())
828 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
829 basePtrI64);
830 Value offsets = adaptor.getOffsets();
831 if (offsets) {
832 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
833 if (offsetsVecTy) {
834 // Offset needs be scalar.
835 return rewriter.notifyMatchFailure(op,
836 "Expected offsets to be a scalar.");
837 } else {
838 int64_t elemBitWidth{0};
839 int64_t elemByteSize;
840 // Element byte size can come from two sources:
841 if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
842 // If memref is available, we use its element type to
843 // determine element byte size.
844 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
845 } else {
846 // Otherwise, we use the provided offset byte alignment.
847 elemByteSize = *op.getOffsetAlignByte();
848 }
849 if (elemBitWidth != 0) {
850 if (elemBitWidth % 8 != 0)
851 return rewriter.notifyMatchFailure(
852 op, "Expected element type bit width to be multiple of 8.");
853 elemByteSize = elemBitWidth / 8;
854 }
855 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
856 elemByteSize);
857 }
858 }
859 // Default memory space is global.
860 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
861 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
862 // If source is a memref, we use its memory space.
863 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
864 auto addrSpace = memRefTy.getMemorySpaceAsInt();
865 if (addrSpace != 0)
866 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
867 }
868 // Convert base pointer (i64) to LLVM pointer type.
869 Value ptrLLVM =
870 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
871 // Create the prefetch op with cache control attribute.
872 xevm::PrefetchOp::create(
873 rewriter, loc, ptrLLVM,
874 xevm::LoadCacheControlAttr::get(
875 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
876 rewriter.eraseOp(op);
877 return success();
878 }
879};
880
881class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
882 using OpConversionPattern::OpConversionPattern;
883 LogicalResult
884 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
885 ConversionPatternRewriter &rewriter) const override {
886 auto loc = op.getLoc();
887 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
888 switch (op.getFenceScope()) {
889 case xegpu::FenceScope::Workgroup:
890 memScope = xevm::MemScope::WORKGROUP;
891 break;
892 case xegpu::FenceScope::GPU:
893 memScope = xevm::MemScope::DEVICE;
894 break;
895 }
896 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
897 switch (op.getMemoryKind()) {
898 case xegpu::MemorySpace::Global:
899 addrSpace = xevm::AddrSpace::GLOBAL;
900 break;
901 case xegpu::MemorySpace::SLM:
902 addrSpace = xevm::AddrSpace::SHARED;
903 break;
904 }
905 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
906 rewriter.eraseOp(op);
907 return success();
908 }
909};
910
911static auto encodePrecision = [](Type type) -> xevm::ElemType {
912 if (type.isBF16())
913 return xevm::ElemType::BF16;
914 else if (type.isF16())
915 return xevm::ElemType::F16;
916 else if (type.isTF32())
917 return xevm::ElemType::TF32;
918 else if (type.isInteger(8)) {
919 if (type.isUnsignedInteger())
920 return xevm::ElemType::U8;
921 return xevm::ElemType::S8;
922 } else if (type.isF32())
923 return xevm::ElemType::F32;
924 else if (type.isInteger(32))
925 return xevm::ElemType::S32;
926 else if (type.isF8E5M2())
927 return xevm::ElemType::BF8;
928 else if (type.isF8E4M3FN())
929 return xevm::ElemType::F8;
930 else if (mlir::isa<Float4E2M1FNType>(type))
931 return xevm::ElemType::E2M1;
932 llvm_unreachable("add more support for ElemType");
933};
934
935static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
936 switch (pTy) {
937 case xevm::ElemType::TF32:
938 return 1;
939 case xevm::ElemType::BF16:
940 case xevm::ElemType::F16:
941 return 2;
942 case xevm::ElemType::U8:
943 case xevm::ElemType::S8:
944 case xevm::ElemType::F8:
945 case xevm::ElemType::BF8:
946 return 4;
947 case xevm::ElemType::E2M1:
948 return 8;
949 default:
950 llvm_unreachable("unsupported xevm::ElemType");
951 }
952}
953
954class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
955 using OpConversionPattern::OpConversionPattern;
956 LogicalResult
957 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
958 ConversionPatternRewriter &rewriter) const override {
959 auto loc = op.getLoc();
960 auto ctxt = rewriter.getContext();
961 auto aTy = cast<VectorType>(op.getLhs().getType());
962 auto bTy = cast<VectorType>(op.getRhs().getType());
963 auto resultType = cast<VectorType>(op.getResultType());
964
965 // get the correct dpasInst by getting info from chip
966 auto chipStr = xegpu::getChipStr(op);
967 if (!chipStr)
968 return rewriter.notifyMatchFailure(op, "cannot determine target chip");
969
970 const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
971 if (!uArch)
972 return rewriter.notifyMatchFailure(op, "unsupported target uArch");
973
974 auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
975 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
976 uArch->getInstruction(
977 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
978 if (!dpasInst)
979 return rewriter.notifyMatchFailure(op,
980 "DPAS not supported by target uArch");
981
982 auto checkSupportedTypes = [&](VectorType vecTy,
983 xegpu::uArch::MMAOpndKind kind) -> bool {
984 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
985 return llvm::find(supported, vecTy.getElementType()) != supported.end();
986 };
987
988 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
989 return rewriter.notifyMatchFailure(
990 op, "A-matrix element type not supported by target uArch");
991 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
992 return rewriter.notifyMatchFailure(
993 op, "B-matrix element type not supported by target uArch");
994 // NOTE: Supported types for MatrixC and MatrixD are identical
995 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
996 return rewriter.notifyMatchFailure(
997 op, "result/accumulator element type not supported by target uArch");
998
999 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1000 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1001 Value c = op.getAcc();
1002 if (!c) {
1003 auto elementTy = resultType.getElementType();
1004 Attribute initValueAttr;
1005 if (isa<FloatType>(elementTy))
1006 initValueAttr = FloatAttr::get(elementTy, 0.0);
1007 else
1008 initValueAttr = IntegerAttr::get(elementTy, 0);
1009 c = arith::ConstantOp::create(
1010 rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
1011 }
1012
1013 Value aVec = op.getLhs();
1014 Value bVec = op.getRhs();
1015 auto cvecty = cast<VectorType>(c.getType());
1016 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
1017 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
1018 VectorType cNty =
1019 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
1020 if (cvecty != cNty)
1021 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
1022 Value dpasRes = xevm::MMAOp::create(
1023 rewriter, loc, cNty, aVec, bVec, c,
1024 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
1025 systolicDepth *
1026 getNumOperandsPerDword(precATy)),
1027 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1028 if (cvecty != cNty)
1029 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
1030 rewriter.replaceOp(op, dpasRes);
1031 return success();
1032 }
1033};
1034
1035static std::optional<LLVM::AtomicBinOp>
1036matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1037 switch (arithKind) {
1038 case arith::AtomicRMWKind::addf:
1039 return LLVM::AtomicBinOp::fadd;
1040 case arith::AtomicRMWKind::addi:
1041 return LLVM::AtomicBinOp::add;
1042 case arith::AtomicRMWKind::assign:
1043 return LLVM::AtomicBinOp::xchg;
1044 case arith::AtomicRMWKind::maximumf:
1045 return LLVM::AtomicBinOp::fmax;
1046 case arith::AtomicRMWKind::maxs:
1047 return LLVM::AtomicBinOp::max;
1048 case arith::AtomicRMWKind::maxu:
1049 return LLVM::AtomicBinOp::umax;
1050 case arith::AtomicRMWKind::minimumf:
1051 return LLVM::AtomicBinOp::fmin;
1052 case arith::AtomicRMWKind::mins:
1053 return LLVM::AtomicBinOp::min;
1054 case arith::AtomicRMWKind::minu:
1055 return LLVM::AtomicBinOp::umin;
1056 case arith::AtomicRMWKind::ori:
1057 return LLVM::AtomicBinOp::_or;
1058 case arith::AtomicRMWKind::andi:
1059 return LLVM::AtomicBinOp::_and;
1060 default:
1061 return std::nullopt;
1062 }
1063}
1064
1065class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
1066 using OpConversionPattern::OpConversionPattern;
1067 LogicalResult
1068 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1069 ConversionPatternRewriter &rewriter) const override {
1070 auto loc = op.getLoc();
1071 auto ctxt = rewriter.getContext();
1072 auto tdesc = op.getTensorDesc().getType();
1073 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1074 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1075 Value basePtrI64 = arith::IndexCastOp::create(
1076 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1077 Value basePtrLLVM =
1078 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1079 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1080 VectorType srcOrDstFlatVecTy = VectorType::get(
1081 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1082 Value srcFlatVec = vector::ShapeCastOp::create(
1083 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1084 auto atomicKind = matchSimpleAtomicOp(op.getKind());
1085 assert(atomicKind.has_value());
1086 Value resVec = srcFlatVec;
1087 for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1088 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1089 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1090 rewriter.getIndexAttr(i));
1091 Value currPtr =
1092 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1093 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1094 Value newVal =
1095 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1096 val, LLVM::AtomicOrdering::seq_cst);
1097 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1098 }
1099 rewriter.replaceOp(op, resVec);
1100 return success();
1101 }
1102};
1103
1104class DpasMxToXeVMPattern : public OpConversionPattern<xegpu::DpasMxOp> {
1105 using OpConversionPattern::OpConversionPattern;
1106 LogicalResult
1107 matchAndRewrite(xegpu::DpasMxOp op, xegpu::DpasMxOp::Adaptor adaptor,
1108 ConversionPatternRewriter &rewriter) const override {
1109 auto loc = op.getLoc();
1110 auto ctxt = rewriter.getContext();
1111 auto aTy = op.getA().getType();
1112 auto bTy = op.getB().getType();
1113 auto resVecTy =
1114 cast<VectorType>(getTypeConverter()->convertType(op.getType()));
1115
1116 auto chipStr = xegpu::getChipStr(op);
1117 if (!chipStr)
1118 return rewriter.notifyMatchFailure(op, "cannot determine target chip");
1119
1120 const auto *uArch = xegpu::uArch::getUArch(*chipStr);
1121 if (!uArch)
1122 return rewriter.notifyMatchFailure(op, "unsupported target uArch");
1123
1124 // TODO: Add supported shape check
1125
1126 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1127 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1128 Value c = adaptor.getAcc();
1129 if (!c) {
1130 auto elementTy = resVecTy.getElementType();
1131 Attribute initValueAttr;
1132 if (isa<FloatType>(elementTy))
1133 initValueAttr = FloatAttr::get(elementTy, 0.0);
1134 else
1135 initValueAttr = IntegerAttr::get(elementTy, 0);
1136 c = arith::ConstantOp::create(
1137 rewriter, loc, DenseElementsAttr::get(resVecTy, initValueAttr));
1138 }
1139
1140 Value aVec = adaptor.getA();
1141 Value bVec = adaptor.getB();
1142 auto aVecTy = cast<VectorType>(aVec.getType());
1143 auto bVecTy = cast<VectorType>(bVec.getType());
1144 if (aVecTy.getElementTypeBitWidth() == 4)
1145 aVec = vector::BitCastOp::create(
1146 rewriter, loc,
1147 VectorType::get(aVecTy.getNumElements() / 2, rewriter.getI8Type()),
1148 aVec);
1149 if (bVecTy.getElementTypeBitWidth() == 4)
1150 bVec = vector::BitCastOp::create(
1151 rewriter, loc,
1152 VectorType::get(bVecTy.getNumElements() / 2, rewriter.getI8Type()),
1153 bVec);
1154 auto cVecTy = cast<VectorType>(c.getType());
1155 xevm::ElemType precCTy = encodePrecision(cVecTy.getElementType());
1156 xevm::ElemType precDTy = encodePrecision(resVecTy.getElementType());
1157 Value scaleA = adaptor.getScaleA();
1158 Value scaleB = adaptor.getScaleB();
1159 Value dpasMxRes = xevm::MMAMxOp::create(
1160 rewriter, loc, resVecTy, aVec, bVec, scaleA, scaleB, c,
1161 xevm::MMAShapeAttr::get(ctxt, cVecTy.getNumElements(), executionSize,
1162 systolicDepth *
1163 getNumOperandsPerDword(precATy)),
1164 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1165 rewriter.replaceOp(op, dpasMxRes);
1166 return success();
1167 }
1168};
1169
1170//===----------------------------------------------------------------------===//
1171// arith.extf / arith.truncf to xevm.extf / xevm.truncf
1172//===----------------------------------------------------------------------===//
1173//
1174// Micro-scaling (MX) GEMM lowering breaks arith.scaling_extf/scaling_truncf
1175// into plain arith.extf/arith.truncf whose narrow side uses one of the MX float
1176// formats (f8E5M2, f8E4M3FN or f4E2M1FN). These narrow floats have no native
1177// LLVM support, so the conversions are mapped onto the dedicated xevm.extf /
1178// xevm.truncf ops which lower to hardware builtins. The f8E8M0FNU scale type is
1179// intentionally not handled here: it is expanded into integer arithmetic by
1180// arith-expand before this pass runs.
1181
// Number of f16/bf16 elements an arith.extf / arith.truncf must produce or
// consume to be lowered to xevm.extf / xevm.truncf: those ops only convert
// between the MX narrow floats and f16/bf16, and the underlying builtins
// operate on exactly 16 f16/bf16 values.
static constexpr int64_t kXeVMExtfTruncfNumElems{16};
1185
1186// Maps a narrow MX float element type to the matching xevm.extf source enum.
1187static std::optional<xevm::ExtfSrcElemTypes> getExtfNarrowType(Type etype) {
1188 if (isa<Float8E5M2Type>(etype))
1189 return xevm::ExtfSrcElemTypes::BF8;
1190 if (isa<Float8E4M3FNType>(etype))
1191 return xevm::ExtfSrcElemTypes::F8;
1192 if (isa<Float4E2M1FNType>(etype))
1193 return xevm::ExtfSrcElemTypes::E2M1;
1194 return std::nullopt;
1195}
1196
1197// Maps a narrow MX float element type to the matching xevm.truncf dest enum.
1198static std::optional<xevm::TruncfDstElemTypes> getTruncfNarrowType(Type etype) {
1199 if (isa<Float8E5M2Type>(etype))
1200 return xevm::TruncfDstElemTypes::BF8;
1201 if (isa<Float8E4M3FNType>(etype))
1202 return xevm::TruncfDstElemTypes::F8;
1203 if (isa<Float4E2M1FNType>(etype))
1204 return xevm::TruncfDstElemTypes::E2M1;
1205 return std::nullopt;
1206}
1207
1208// Returns true if `op` is an arith.extf that can be lowered to xevm.extf, i.e.
1209// a rank-1 widening from an MX narrow float to a 16-element f16/bf16 vector.
1210static bool isXeVMExtf(arith::ExtFOp op) {
1211 auto srcTy = dyn_cast<VectorType>(op.getIn().getType());
1212 auto dstTy = dyn_cast<VectorType>(op.getType());
1213 if (!srcTy || !dstTy || srcTy.getRank() != 1 || dstTy.getRank() != 1)
1214 return false;
1215 if (dstTy.getNumElements() != kXeVMExtfTruncfNumElems)
1216 return false;
1217 Type dstETy = dstTy.getElementType();
1218 if (!dstETy.isF16() && !dstETy.isBF16())
1219 return false;
1220 return getExtfNarrowType(srcTy.getElementType()).has_value();
1221}
1222
1223// Returns true if `op` is an arith.truncf that can be lowered to xevm.truncf,
1224// i.e. a rank-1 truncation from a 16-element f16/bf16 vector to an MX narrow
1225// float.
1226static bool isXeVMTruncf(arith::TruncFOp op) {
1227 auto srcTy = dyn_cast<VectorType>(op.getIn().getType());
1228 auto dstTy = dyn_cast<VectorType>(op.getType());
1229 if (!srcTy || !dstTy || srcTy.getRank() != 1 || dstTy.getRank() != 1)
1230 return false;
1231 if (srcTy.getNumElements() != kXeVMExtfTruncfNumElems)
1232 return false;
1233 Type srcETy = srcTy.getElementType();
1234 if (!srcETy.isF16() && !srcETy.isBF16())
1235 return false;
1236 return getTruncfNarrowType(dstTy.getElementType()).has_value();
1237}
1238
1239class ExtfToXeVMPattern : public OpConversionPattern<arith::ExtFOp> {
1240 using OpConversionPattern::OpConversionPattern;
1241 LogicalResult
1242 matchAndRewrite(arith::ExtFOp op, OpAdaptor adaptor,
1243 ConversionPatternRewriter &rewriter) const override {
1244 if (!isXeVMExtf(op))
1245 return rewriter.notifyMatchFailure(op, "not a xevm.extf compatible extf");
1246 Location loc = op.getLoc();
1247 MLIRContext *ctx = op.getContext();
1248 auto srcVecTy = cast<VectorType>(op.getIn().getType());
1249 auto dstVecTy = cast<VectorType>(op.getType());
1250 xevm::ExtfSrcElemTypes srcEnum =
1251 *getExtfNarrowType(srcVecTy.getElementType());
1252 xevm::ExtfDstElemTypes dstEnum = dstVecTy.getElementType().isF16()
1253 ? xevm::ExtfDstElemTypes::F16
1254 : xevm::ExtfDstElemTypes::BF16;
1255 // The narrow float operand has already been type-converted to an integer
1256 // vector of the same bit width (i4 for fp4, i8 for fp8). xevm.extf takes
1257 // the values packed into an i8 vector, so re-pack fp4 (i4) operands.
1258 Value src = adaptor.getIn();
1259 auto convSrcTy = cast<VectorType>(src.getType());
1260 if (convSrcTy.getElementTypeBitWidth() == 4)
1261 src = vector::BitCastOp::create(
1262 rewriter, loc,
1263 VectorType::get(convSrcTy.getNumElements() / 2, rewriter.getI8Type()),
1264 src);
1265 Type resTy = getTypeConverter()->convertType(dstVecTy);
1266 Value res = xevm::ExtfOp::create(
1267 rewriter, loc, resTy, src, xevm::ExtfSrcElemTypeAttr::get(ctx, srcEnum),
1268 xevm::ExtfDstElemTypeAttr::get(ctx, dstEnum));
1269 rewriter.replaceOp(op, res);
1270 return success();
1271 }
1272};
1273
1274class TruncfToXeVMPattern : public OpConversionPattern<arith::TruncFOp> {
1275 using OpConversionPattern::OpConversionPattern;
1276 LogicalResult
1277 matchAndRewrite(arith::TruncFOp op, OpAdaptor adaptor,
1278 ConversionPatternRewriter &rewriter) const override {
1279 if (!isXeVMTruncf(op))
1280 return rewriter.notifyMatchFailure(op,
1281 "not a xevm.truncf compatible truncf");
1282 Location loc = op.getLoc();
1283 MLIRContext *ctx = op.getContext();
1284 auto srcVecTy = cast<VectorType>(op.getIn().getType());
1285 auto dstVecTy = cast<VectorType>(op.getType());
1286 xevm::TruncfSrcElemTypes srcEnum = srcVecTy.getElementType().isF16()
1287 ? xevm::TruncfSrcElemTypes::F16
1288 : xevm::TruncfSrcElemTypes::BF16;
1289 xevm::TruncfDstElemTypes dstEnum =
1290 *getTruncfNarrowType(dstVecTy.getElementType());
1291 // xevm.truncf produces the narrow floats packed into an i8 vector.
1292 int64_t numNarrowBits =
1293 dstVecTy.getNumElements() * dstVecTy.getElementTypeBitWidth();
1294 Type packedTy = VectorType::get(numNarrowBits / 8, rewriter.getI8Type());
1295 Value res =
1296 xevm::TruncfOp::create(rewriter, loc, packedTy, adaptor.getIn(),
1297 xevm::TruncfSrcElemTypeAttr::get(ctx, srcEnum),
1298 xevm::TruncfDstElemTypeAttr::get(ctx, dstEnum));
1299 // Re-shape to the type-converted result type (i4 vector for fp4).
1300 Type resTy = getTypeConverter()->convertType(dstVecTy);
1301 if (res.getType() != resTy)
1302 res = vector::BitCastOp::create(rewriter, loc, resTy, res);
1303 rewriter.replaceOp(op, res);
1304 return success();
1305 }
1306};
1307
1308//===----------------------------------------------------------------------===//
1309// Pass Definition
1310//===----------------------------------------------------------------------===//
1311
1312struct ConvertXeGPUToXeVMPass
1313 : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
1314 using Base::Base;
1315
1316 void runOnOperation() override {
1317 MLIRContext *context = &getContext();
1318
1319 // XeVM type converter is based on LLVM type converter with the
1320 // following customizations.
1321 // First, type conversion rules are added for xegpu custom types,
1322 // TensorDescType and MemDescType.
1323 // Second, MemRefType is lowered to single integer type
1324 // Third, VectorType of single element or 0D is converted to vector
1325 // element type. Otherwise, vector type is flatten to 1D.
1326 LowerToLLVMOptions options(context);
1327 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1328 LLVMTypeConverter typeConverter(context, options);
1329
1330 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1331 Type i32Type = IntegerType::get(context, 32);
1332 typeConverter.addConversion([&](VectorType type) -> Type {
1333 auto elemType = typeConverter.convertType(type.getElementType());
1334 // If the vector rank is 0 or has a single element, return the element
1335 unsigned rank = type.getRank();
1336 if (rank == 0 || type.getNumElements() == 1)
1337 return elemType;
1338 // Otherwise, convert the vector to a flat vector type.
1339 int64_t sum = llvm::product_of(type.getShape());
1340 return VectorType::get(sum, elemType);
1341 });
1342 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1343 if (type.getRank() == 1)
1344 return xevmIndexType;
1345 return VectorType::get(8, i32Type);
1346 });
1347 // SLM access related type conversions.
1348 // TODO: LLVM DLTI provides clean way of representing different pointer size
1349 // based on address space. Currently pointer size of SLM access is hard
1350 // coded to 32bit. Update to use DLTI when switching overall XeGPU lowering
1351 // to use DLTI instead of use64bitIndex option used above.
1352
1353 // Convert MemDescType into i32 for SLM
1354 typeConverter.addConversion(
1355 [&](xegpu::MemDescType type) -> Type { return i32Type; });
1356
1357 typeConverter.addConversion([&](MemRefType type) -> Type {
1358 return isSharedMemRef(type) ? i32Type : xevmIndexType;
1359 });
1360
1361 // LLVM type converter puts unrealized casts for the following cases:
1362 // add materialization casts to handle them.
1363
1364 // Materialization to convert memref to i64 or i32 depending on global/SLM
1365 // Applies only to target materialization.
1366 // Note: int type to memref materialization is not required as xegpu ops
1367 // currently do not produce memrefs as result.
1368 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1369 ValueRange inputs,
1370 Location loc) -> Value {
1371 if (inputs.size() != 1)
1372 return {};
1373 auto input = inputs.front();
1374 if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1375 unsigned rank = memrefTy.getRank();
1376 Type indexType = builder.getIndexType();
1377
1378 int64_t intOffsets;
1379 SmallVector<int64_t> intStrides;
1380 Value addr;
1381 Value offset;
1382 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1383 ShapedType::isStatic(intOffsets)) {
1384 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1385 input);
1386 offset = arith::ConstantOp::create(builder, loc,
1387 builder.getIndexAttr(intOffsets));
1388 } else {
1389
1390 // Result types: [base_memref, offset, stride0, stride1, ...,
1391 // strideN-1, size0, size1, ..., sizeN-1]
1392 SmallVector<Type> resultTypes{
1393 MemRefType::get({}, memrefTy.getElementType(),
1394 MemRefLayoutAttrInterface(),
1395 memrefTy.getMemorySpace()),
1396 indexType};
1397 // strides + sizes
1398 resultTypes.append(2 * rank, indexType);
1399
1400 auto meta = memref::ExtractStridedMetadataOp::create(
1401 builder, loc, resultTypes, input);
1402
1403 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1404 builder, loc, meta.getBaseBuffer());
1405 offset = meta.getOffset();
1406 }
1407
1408 auto addrCasted =
1409 arith::IndexCastUIOp::create(builder, loc, type, addr);
1410 auto offsetCasted =
1411 arith::IndexCastUIOp::create(builder, loc, type, offset);
1412
1413 // Compute the final address: base address + byte offset
1414 auto byteSize = arith::ConstantOp::create(
1415 builder, loc, type,
1416 builder.getIntegerAttr(type,
1417 memrefTy.getElementTypeBitWidth() / 8));
1418 auto byteOffset =
1419 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1420 auto addrWithOffset =
1421 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1422
1423 return addrWithOffset.getResult();
1424 }
1425 return {};
1426 };
1427
1428 // Materialization to convert ui64 to i64
1429 // Applies only to target materialization.
1430 // Note: i64 to ui64 materialization is not required as xegpu ops
1431 // currently do not produce ui64 as result.
1432 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1433 ValueRange inputs,
1434 Location loc) -> Value {
1435 if (inputs.size() != 1)
1436 return {};
1437 auto input = inputs.front();
1438 if (input.getType() == builder.getIntegerType(64, false)) {
1439 Value cast =
1440 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1441 .getResult();
1442 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1443 .getResult();
1444 }
1445 return {};
1446 };
1447
1448 // Materialization to convert ui32 to i32
1449 // Applies only to target materialization.
1450 // Note: i32 to ui32 materialization is not required as xegpu ops
1451 // currently do not produce ui32 as result.
1452 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1453 ValueRange inputs,
1454 Location loc) -> Value {
1455 if (inputs.size() != 1)
1456 return {};
1457 auto input = inputs.front();
1458 if (input.getType() == builder.getIntegerType(32, false)) {
1459 Value cast =
1460 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1461 .getResult();
1462 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1463 .getResult();
1464 }
1465 return {};
1466 };
1467
1468 // Materialization to convert between vector types
1469 // - Add shape cast for different shapes
1470 // - Add bitcast for different element types
1471 // Applies to both source and target materialization.
1472 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1473 ValueRange inputs,
1474 Location loc) -> Value {
1475 if (inputs.size() != 1)
1476 return {};
1477 auto input = inputs.front();
1478 if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
1479 if (auto targetVecTy = dyn_cast<VectorType>(type)) {
1480 Value cast = input;
1481 // If the target type has a different shape, add a shape cast
1482 // If the target type has a different element type, add a bitcast
1483 if (targetVecTy.getShape() != vecTy.getShape()) {
1484 cast = vector::ShapeCastOp::create(
1485 builder, loc,
1486 VectorType::get(targetVecTy.getShape(),
1487 vecTy.getElementType()),
1488 cast)
1489 .getResult();
1490 }
1491 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1492 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
1493 .getResult();
1494 }
1495 return cast;
1496 }
1497 }
1498 return {};
1499 };
1500
1501 // Materialization to convert
1502 // - single element vector to single element of vector element type
1503 // Applies only to target materialization.
1504 auto vectorToSingleElementMaterializationCast =
1505 [](OpBuilder &builder, Type type, ValueRange inputs,
1506 Location loc) -> Value {
1507 if (inputs.size() != 1)
1508 return {};
1509 auto input = inputs.front();
1510 if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
1511 // Source needs to be single element vector
1512 auto rank = vecTy.getRank();
1513 if (rank != 0 && vecTy.getNumElements() != 1)
1514 return {};
1515 auto inElemTy = vecTy.getElementType();
1516 // extract scalar
1517 Value cast = input;
1518 if (rank == 0) {
1519 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1520 } else {
1521 cast = vector::ExtractOp::create(builder, loc, cast,
1522 SmallVector<int64_t>(rank, 0))
1523 .getResult();
1524 }
1525 // Extracted element type may need conversion
1526 // Two cases
1527 // 1. Index type to integer type
1528 // 2. Other element type mismatch
1529 if (inElemTy.isIndex()) {
1530 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1531 .getResult();
1532 } else if (inElemTy != type) {
1533 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
1534 }
1535 return cast;
1536 }
1537 return {};
1538 };
1539
1540 // Materialization to convert
1541 // - single element of vector element type to single element vector
1542 // If result type of original op is single element vector and lowered type
1543 // is scalar. This materialization cast creates a single element vector by
1544 // First convert element type if needed and then broadcast to single
1545 // element vector.
1546 // Applies only to source materialization.
1547 auto singleElementToVectorMaterializationCast =
1548 [](OpBuilder &builder, Type type, ValueRange inputs,
1549 Location loc) -> Value {
1550 if (inputs.size() != 1)
1551 return {};
1552 auto input = inputs.front();
1553 auto inTy = input.getType();
1554 if (!inTy.isIntOrFloat())
1555 return {};
1556 // If the target type is a vector of rank 0 or single element vector
1557 // of element type matching input type, broadcast input to target type.
1558 if (auto vecTy = dyn_cast<VectorType>(type)) {
1559 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1560 return {};
1561 auto outElemTy = vecTy.getElementType();
1562 Value cast = input;
1563 if (outElemTy.isIndex()) {
1564 cast = arith::IndexCastUIOp::create(builder, loc,
1565 builder.getIndexType(), cast)
1566 .getResult();
1567 } else if (inTy != outElemTy) {
1568 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1569 .getResult();
1570 }
1571 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
1572 .getResult();
1573 }
1574 return {};
1575 };
1576 typeConverter.addSourceMaterialization(
1577 singleElementToVectorMaterializationCast);
1578 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1579 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1580 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1581 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1582 typeConverter.addTargetMaterialization(
1583 vectorToSingleElementMaterializationCast);
1584 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
1585 ConversionTarget target(*context);
1586 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1587 vector::VectorDialect, arith::ArithDialect,
1588 memref::MemRefDialect, gpu::GPUDialect,
1589 index::IndexDialect>();
1590 target.addIllegalDialect<xegpu::XeGPUDialect>();
1591 // arith.extf/arith.truncf between MX narrow floats and f16/bf16 are routed
1592 // to xevm.extf/xevm.truncf; all other arith float casts stay legal.
1593 target.addDynamicallyLegalOp<arith::ExtFOp>(
1594 [](arith::ExtFOp op) { return !isXeVMExtf(op); });
1595 target.addDynamicallyLegalOp<arith::TruncFOp>(
1596 [](arith::TruncFOp op) { return !isXeVMTruncf(op); });
1597
1598 RewritePatternSet patterns(context);
1599 populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
1601 patterns, target);
1602 if (failed(applyPartialConversion(getOperation(), target,
1603 std::move(patterns))))
1604 signalPassFailure();
1605 }
1606};
1607} // namespace
1608
1609//===----------------------------------------------------------------------===//
1610// Pattern Population
1611//===----------------------------------------------------------------------===//
1613 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1614 patterns.add<CreateNdDescToXeVMPattern,
1615 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1616 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1617 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1618 typeConverter, patterns.getContext());
1619 patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1620 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1621 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1622 typeConverter, patterns.getContext());
1623 patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1624 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1625 CreateMemDescOpPattern>(typeConverter, patterns.getContext());
1626 patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1627 patterns.getContext());
1628 patterns.add<DpasMxToXeVMPattern>(typeConverter, patterns.getContext());
1629 patterns.add<ExtfToXeVMPattern, TruncfToXeVMPattern>(typeConverter,
1630 patterns.getContext());
1631}
return success()
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:607
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:283
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:105
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)