#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
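  // Editorial example (illustrative): with trans = [1, 0] and old = [8, 16],
  // the loop above yields shape = [16, 8], i.e. shape is the permutation of
  // the old dims selected by trans.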
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  llvm::raw_string_ostream os(buf);
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";

  os << array.back() << "]";
  if (auto ty = llvm::dyn_cast<ShapedType>(type))

  if (auto ty = llvm::dyn_cast<ShapedType>(type))
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;

  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
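// Editorial note: read hints accept CACHED, UNCACHED, STREAMING and
// READ_INVALIDATE, while write hints accept CACHED, UNCACHED, WRITE_BACK and
// WRITE_THROUGH; per the *OrNone naming, an absent hint attribute is
// presumably accepted by both checkers.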
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  build(builder, state, tdesc, source, dynamicOffsets,
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,

  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
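// Editorial note (sketch): the offsets/shape/strides arrive as OpFoldResults
// and are each split into dynamic Value operands plus a static i64 array
// attribute (presumably via dispatchIndexOpFoldResults in the elided lines),
// so a single op form covers mixed static/dynamic descriptors.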
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());

    invalidRank |= (memrefTy.getRank() != rank);
152 "Expecting the rank of shape, strides, offsets, and source (if source "
153 "is a memref) should match with each other.");
156 invalidRank = (
getType().getRank() > 2 ||
getType().getRank() > rank);
160 "Expecting the TensorDesc rank is up to 2 and not greater than the "
161 "ranks of shape, strides, offsets or the memref source.");
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

    return emitOpError("Expects a non-scattered TensorDesc.\n");
  if (getType().getRank() == 2 &&
      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto tdescTy = getTensorDescType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("Invalid result, it should be a VectorType.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto array_len = tdescTy.getArrayLength();

  if (getTranspose()) {
    auto trans = getTranspose().value();

    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();

      emitWarning("Invalid transpose attr. It is ignored.");

    if (tdescTy.getRank() == 2) {
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);

      emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
                  "TensorDesc only).");

    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  if (tdescShape != valueShape)
    return emitOpError() << "Result shape doesn't match TensorDesc shape."
                         << " The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
  auto dstTy = getTensorDescType();
  auto valTy = getValueType();

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("Expecting a VectorType result.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,

  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());

  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,

  build(builder, state, TensorDesc, source, ofrs);
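// Editorial note (sketch): the first overload above presumably materializes
// each offset as an index Value (constants handled via
// getValueOrCreateConstantIndexOp in the elided lines) and packs them into a
// vector with vector::FromElementsOp, which becomes the single offsets
// operand of create_desc; the second overload simply forwards OpFoldResults.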
  auto tdescTy = getTensorDescType();

        "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;
  auto chunkSize = tdescTy.getChunkSize();

  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                              16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");

  auto lscConstraints = 512 * 8;
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) must not exceed 512 bytes.");
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});

    shape.push_back(chunkSize);

  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected shape is " << makeString(shape) << "\n";
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();

  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load_gather has to be transposed.");

  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << " (Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("store_scatter has to be transposed.");

  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape")
           << " (Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,

  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source to be a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());

  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,

  build(builder, state, tensorDesc, ofrs);
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();

  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
                       "a 2D or 3D (packed) vector.");

  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>