Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/nn/linear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BaseLinear : public Module {
Tensor gidx() const { return gidx_; }

std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization() const { return quantization_; }
void process_weights_after_loading();

protected:
// Parameters
Expand Down
5 changes: 5 additions & 0 deletions include/infinicore/nn/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class Module {
public:
Module() = default;

virtual ~Module() = default;

const std::unordered_map<std::string, Parameter> &state_dict() const;

void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);
Expand All @@ -23,6 +25,8 @@ class Module {

void load_parameter_from_blob(const std::string &name, const void *data);

std::unordered_map<std::string, Module *> modules_dict() const;

protected:
Tensor register_parameter(const std::string &name, Parameter param);

Expand Down Expand Up @@ -83,6 +87,7 @@ class Module {
private:
void load_state_dict_recursively(const std::unordered_map<std::string, Tensor> &_state_dict, const std::string &prefix = "");
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
void collect_all_modules(std::unordered_map<std::string, Module *> &out, const std::string &prefix) const;
};

// ============================================================================
Expand Down
1 change: 1 addition & 0 deletions include/infinicore/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "quantization/awq.hpp"
#include "quantization/base_quantization.hpp"
#include "quantization/compressed_tensors.hpp"
#include "quantization/gptq.hpp"
#include "quantization/gptq_qy.hpp"
#include "quantization/none_quantizaiton.hpp"
#include "quantization/quantization_scheme.hpp"
2 changes: 2 additions & 0 deletions include/infinicore/quantization/base_quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class BaseQuantization {
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
virtual ~BaseQuantization() = default;

const nlohmann::json &get_config() const { return quant_config_; }

virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
template <typename T>
T get(const std::string &key) const {
Expand Down
30 changes: 30 additions & 0 deletions include/infinicore/quantization/gptq.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once
#include "base_quantization.hpp"
namespace infinicore::quantization {

class GPTQ : public BaseQuantization {
    // Temporary adapter: currently always reports the GPTQ W4A16 scheme.
    // Future enhancements should parse quant_config to extract detailed
    // quantization information and support multiple quantization schemes.
public:
    /// Construct from the model's quantization config JSON; the config is
    /// stored by the BaseQuantization base and queried lazily via get_or<>.
    explicit GPTQ(const nlohmann::json &quant_config)
        : BaseQuantization(quant_config) {}

    /// Always reports GPTQ W4A16 for now (see class note above).
    infinicore::quantization::QuantScheme
    get_quant_scheme() const override {
        return infinicore::quantization::QuantScheme::GPTQ_W4A16;
    }

    /// Number of quantized weights packed into a single int32 storage word.
    /// "bits" defaults to 4 when absent from the config, so the packing
    /// number defaults to 32 / 4 = 8. (The previous comment claimed the
    /// *default* was 8 — the default is the bit width, not the result.)
    int get_packing_num() const {
        return 32 / this->get_or<int>("bits", 4);
    }

    /// Quantization group size; falls back to 128 (the standard GPTQ group
    /// size) when "group_size" is absent from the config.
    int get_group_size() const {
        return this->get_or<int>("group_size", 128);
    }
};

} // namespace infinicore::quantization
Loading
Loading